Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-18 08:36:43 +08:00)
Add packed_affine_qmm_t
This commit is contained in:
parent 410ccdbed5
commit fb7be036af
@@ -1,19 +1,19 @@
 import argparse
 import math
-from functools import partial
 
 import mlx.core as mx
 from time_utils import time_fn
 
+B = 1024
 D = 1024
 M = 4 * D
 group_size = 64
 bits = 4
 dtype = mx.float16
-loops = 100
+loops = 10
 
 
-def qmv_(x, wq1, wq2, q_type):
+def qmm_(x, wq1, wq2, q_type):
     for i in range(loops):
         x = mx.quantized_matmul(
             x,
@@ -32,28 +32,28 @@ def qmv_(x, wq1, wq2, q_type):
     return x
 
 
-def affine_qmv(x, wq1, wq2):
-    return qmv_(x, wq1, wq2, "affine")
+def affine_qmm(x, wq1, wq2):
+    return qmm_(x, wq1, wq2, "affine")
 
 
-def affine_packed_qmv(x, wq1, wq2):
-    return qmv_(x, wq1, wq2, "affine-packed")
+def affine_packed_qmm(x, wq1, wq2):
+    return qmm_(x, wq1, wq2, "affine-packed")
 
 
-def time_qmv():
+def time_qmm():
     mx.random.seed(3)
-    x = mx.random.normal(shape=(1, D)).astype(dtype)
+    x = mx.random.normal(shape=(B, D)).astype(dtype)
     w1 = mx.random.normal(shape=(M, D)).astype(dtype)
     wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
     w2 = mx.random.normal(shape=(D, M)).astype(dtype)
     wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
     mx.eval(x, wq1, wq2)
-    time_fn(affine_qmv, x, wq1, wq2)
+    time_fn(affine_qmm, x, wq1, wq2)
 
 
-def time_packed_qmv():
+def time_packed_qmm():
     mx.random.seed(3)
-    x = mx.random.normal(shape=(1, D)).astype(dtype)
+    x = mx.random.normal(shape=(B, D)).astype(dtype)
     w1 = mx.random.normal(shape=(M, D)).astype(dtype)
     wq1 = mx.quantize(
         w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
@@ -63,12 +63,12 @@ def time_packed_qmv():
         w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
     )
     mx.eval(x, wq1, wq2)
-    time_fn(affine_packed_qmv, x, wq1, wq2)
+    time_fn(affine_packed_qmm, x, wq1, wq2)
 
 
 if __name__ == "__main__":
     for b in [2, 4, 8]:
         bits = b
         print(f"Bits {bits}:")
-        time_qmv()
-        time_packed_qmv()
+        time_qmm()
+        time_packed_qmm()
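Note on the benchmark change (an illustration, not part of the diff): with B = 1024 the script now times batched quantized matmuls instead of single-row QMV calls, sweeping bits over 2, 4 and 8. The amount of quantized data each call streams can be sanity-checked with a small Python sketch; the 32/bits packing and the per-group scale/bias pair are assumptions taken from the kernel code further down, the rest is plain arithmetic.

# Rough storage math for the benchmark shapes (sketch only, assumes 32 // bits
# values per uint32 word and one scale/bias pair per group of group_size values).
B, D = 1024, 1024
M = 4 * D
group_size = 64
for bits in (2, 4, 8):
    pack_factor = 32 // bits                # quantized values per packed uint32
    packed_words = M * D // pack_factor     # uint32 words for one (M, D) weight
    scale_bias_pairs = M * D // group_size  # one (scale, bias) pair per group
    print(f"bits={bits}: {packed_words} words, {scale_bias_pairs} scale/bias pairs")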
@@ -1248,6 +1248,41 @@ METAL_FUNC void adjust_matrix_offsets(
   y += tid.z * output_stride;
 }
 
+template <typename T>
+METAL_FUNC void adjust_matrix_offsets(
+    const device T*& x,
+    const device uint32_t*& w,
+    const device T*& scales,
+    device T*& y,
+    int output_stride,
+    const constant int& x_batch_ndims,
+    const constant int* x_shape,
+    const constant int64_t* x_strides,
+    const constant int& w_batch_ndims,
+    const constant int* w_shape,
+    const constant int64_t* w_strides,
+    const constant int64_t* s_strides,
+    uint3 tid [[threadgroup_position_in_grid]]) {
+  // Set the input/output matrices
+  uint32_t x_idx = tid.z;
+  uint32_t w_idx = tid.z;
+  if (x_batch_ndims == 1) {
+    x += x_idx * x_strides[0];
+  } else {
+    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
+  }
+  if (w_batch_ndims == 1) {
+    w += w_idx * w_strides[0];
+    scales += w_idx * s_strides[0];
+  } else {
+    ulong2 idx = elem_to_loc_broadcast(
+        w_idx, w_shape, w_strides, s_strides, w_batch_ndims);
+    w += idx.x;
+    scales += idx.y;
+  }
+  y += tid.z * output_stride;
+}
+
 template <typename T>
 METAL_FUNC void adjust_matrix_offsets(
     const device T*& x,
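Aside (not part of the diff): this overload carries only a scales pointer, since in the packed-affine layout the biases appear to live alongside the scales (see the loader further down, where biases = scales + row_pack_factor), so there is no separate biases pointer to advance. The elem_to_loc indexing it relies on is ordinary shape/stride arithmetic; a Python sketch of the assumed semantics:

# Assumed semantics of elem_to_loc: map a flat batch index to a memory offset
# given a shape and strides (sketch for illustration, not MLX source).
def elem_to_loc(idx, shape, strides):
    loc = 0
    for dim in reversed(range(len(shape))):
        loc += (idx % shape[dim]) * strides[dim]
        idx //= shape[dim]
    return loc

# Batch element 5 of a weight with a (2, 3) batch and row-major batch strides.
print(elem_to_loc(5, [2, 3], [3, 1]))  # 5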
@@ -2266,11 +2301,322 @@ template <typename T, int group_size, int bits>
     const device vec<T, 4>* scales [[buffer(1)]],
     const device T* x [[buffer(2)]],
     device T* y [[buffer(3)]],
-    const constant int& in_vec_size [[buffer(5)]],
-    const constant int& out_vec_size [[buffer(6)]],
+    const constant int& in_vec_size [[buffer(4)]],
+    const constant int& out_vec_size [[buffer(5)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   affine_packed_qmv_fast_impl<T, group_size, bits>(
       w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
 }
+
+template <
+    typename T,
+    short BROWS,
+    short BCOLS,
+    short dst_ld,
+    short reduction_dim,
+    short tgp_size,
+    short group_size,
+    short bits>
+struct AffinePackedQuantizedBlockLoader {
+  static_assert(
+      BCOLS <= group_size,
+      "The group size should be larger than the columns");
+  static_assert(
+      group_size % BCOLS == 0,
+      "The group size should be divisible by the columns");
+  static_assert(
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+
+  MLX_MTL_CONST short pack_factor = 32 / bits;
+  MLX_MTL_CONST short row_pack_factor = 4;
+  MLX_MTL_CONST short BCOLS_PACKED = BCOLS * row_pack_factor / pack_factor;
+  MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
+  MLX_MTL_CONST short TOTAL_INTS = BCOLS_PACKED * BROWS_PACKED;
+  MLX_MTL_CONST short n_reads =
+      (TOTAL_INTS < tgp_size) ? 1 : TOTAL_INTS / tgp_size;
+  MLX_MTL_CONST short group_steps = group_size / BCOLS;
+
+  static_assert(
+      n_reads <= row_pack_factor,
+      "The loader only supports per thread reads <= row_pack_factor");
+
+  const int src_ld;
+  const int tile_stride;
+  short group_step_cnt;
+  const int group_stride;
+
+  const short thread_idx;
+  const short bi;
+  const short bj;
+  const short bii;
+  const short bjj;
+
+  const device uint32_t* src;
+  const device T* scales;
+  const device T* biases;
+  threadgroup T* dst;
+
+  AffinePackedQuantizedBlockLoader(
+      const device uint32_t* src_,
+      const device T* scales_,
+      const int src_ld_,
+      threadgroup T* dst_,
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(src_ld_),
+        tile_stride(reduction_dim ? BCOLS_PACKED : BROWS_PACKED * src_ld),
+        group_step_cnt(0),
+        group_stride(BROWS_PACKED * 2 * src_ld / group_size),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(n_reads * thread_idx / BCOLS_PACKED),
+        bj((n_reads * thread_idx) % BCOLS_PACKED),
+        bii(bi * row_pack_factor + bj % row_pack_factor),
+        bjj(bj / row_pack_factor),
+        src(src_ + bi * src_ld * row_pack_factor / pack_factor + bj),
+        scales(
+            scales_ + bi * 2 * src_ld * row_pack_factor / group_size +
+            bj % row_pack_factor),
+        biases(scales + row_pack_factor),
+        dst(dst_ + bii * dst_ld + bjj * pack_factor) {}
+
+  void load_unsafe() const {
+    if (bits == 2 && BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    for (int i = 0; i < n_reads; i++) {
+      T scale = scales[i];
+      T bias = biases[i];
+      dequantize<T, pack_factor, bits>(
+          (const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld);
+    }
+  }
+
+  void load_safe(short2 src_tile_dim) const {
+    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    if (reduction_dim == 1 && bii >= src_tile_dim.y) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    if (reduction_dim == 0 && bii >= src_tile_dim.x) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    for (int i = 0; i < n_reads; i++) {
+      T scale = scales[i];
+      T bias = biases[i];
+      dequantize<T, pack_factor, bits>(
+          (const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld);
+    }
+  }
+
+  void next() {
+    src += tile_stride;
+    if (reduction_dim == 1) {
+      if (group_steps > 1) {
+        group_step_cnt++;
+        if (group_step_cnt == group_steps) {
+          group_step_cnt = 0;
+          scales += 8;
+          biases += 8;
+        }
+      } else {
+        scales += 8;
+        biases += 8;
+      }
+    } else {
+      scales += group_stride;
+      biases += group_stride;
+    }
+  }
+};
+
+template <
+    typename T,
+    const int group_size,
+    const int bits,
+    const bool aligned_N,
+    const int BM = 32,
+    const int BK = 32,
+    const int BN = 32>
+METAL_FUNC void affine_packed_qmm_t_impl(
+    const device uint32_t* w,
+    const device T* scales,
+    const device T* x,
+    device T* y,
+    threadgroup T* Xs,
+    threadgroup T* Ws,
+    const constant int& K,
+    const constant int& N,
+    const constant int& M,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
+  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
+
+  (void)lid;
+
+  constexpr int WM = 2;
+  constexpr int WN = 2;
+  constexpr int pack_factor = 32 / bits;
+  constexpr int row_pack_factor = 4;
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+
+  // Instantiate the appropriate BlockMMA and Loader
+  using mma_t = mlx::steel::
+      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
+  using loader_x_t =
+      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
+  using loader_w_t = AffinePackedQuantizedBlockLoader<
+      T,
+      BN,
+      BK,
+      BK_padded,
+      1,
+      WM * WN * SIMD_SIZE,
+      group_size,
+      bits>;
+
+  // Set the block
+  const int K_w = K * row_pack_factor / pack_factor;
+  const int K_g = K * 2 * row_pack_factor / group_size;
+  const int y_row = tid.y * BM;
+  const int y_col = tid.x * BN;
+  const int packed_y_col = tid.x * (BN / row_pack_factor);
+
+  x += y_row * K;
+  w += packed_y_col * K_w;
+  scales += packed_y_col * K_g;
+  y += y_row * N + y_col;
+
+  // Make the x loader and mma operation
+  const short num_els = min(BM, M - y_row);
+  const short num_outs = min(BN, N - y_col);
+  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
+  loader_w_t loader_w(w, scales, K, Ws, simd_gid, simd_lid);
+  mma_t mma_op(simd_gid, simd_lid);
+
+  if (num_els < BM) {
+    if (!aligned_N && num_outs < BN) {
+      for (int k = 0; k < K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_safe(short2(BK, num_els));
+        loader_w.load_safe(short2(BK, num_outs));
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    } else {
+      for (int k = 0; k < K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_safe(short2(BK, num_els));
+        loader_w.load_unsafe();
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    }
+  } else {
+    if (!aligned_N && num_outs < BN) {
+      for (int k = 0; k < K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_unsafe();
+        loader_w.load_safe(short2(BK, num_outs));
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    } else {
+      for (int k = 0; k < K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        loader_x.load_unsafe();
+        loader_w.load_unsafe();
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        mma_op.mma(Xs, Ws);
+        loader_x.next();
+        loader_w.next();
+      }
+    }
+  }
+
+  // Store results to device memory
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  if (num_els < BM || num_outs < BN) {
+    mma_op.store_result_safe(y, N, short2(num_outs, num_els));
+  } else {
+    mma_op.store_result(y, N);
+  }
+}
+
+template <
+    typename T,
+    const int group_size,
+    const int bits,
+    const bool aligned_N,
+    const bool batched,
+    const int BM = 32,
+    const int BK = 32,
+    const int BN = 32>
+[[kernel]] void affine_packed_qmm_t(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* x [[buffer(2)]],
+    device T* y [[buffer(3)]],
+    const constant int& K [[buffer(4)]],
+    const constant int& N [[buffer(5)]],
+    const constant int& M [[buffer(6)]],
+    const constant int& x_batch_ndims [[buffer(7)]],
+    const constant int* x_shape [[buffer(8)]],
+    const constant int64_t* x_strides [[buffer(9)]],
+    const constant int& w_batch_ndims [[buffer(10)]],
+    const constant int* w_shape [[buffer(11)]],
+    const constant int64_t* w_strides [[buffer(12)]],
+    const constant int64_t* s_strides [[buffer(13)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  (void)lid;
+
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[BN * BK_padded];
+
+  if (batched) {
+    adjust_matrix_offsets<T>(
+        x,
+        w,
+        scales,
+        y,
+        M * N,
+        x_batch_ndims,
+        x_shape,
+        x_strides,
+        w_batch_ndims,
+        w_shape,
+        w_strides,
+        s_strides,
+        tid);
+  }
+  affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
+      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+}
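For orientation (an illustration, not part of the diff): the loader appears to pack row_pack_factor = 4 output rows into each run of uint32 words, so a BN x BK weight tile becomes BROWS_PACKED x BCOLS_PACKED packed words. Evaluating the same expressions in Python for the default 32 x 32 tile and the 2 x 2 simdgroup layout used above (tgp_size = 128) shows how many words each thread reads per tile and how many K-steps share one scale/bias group:

# Tiling constants from AffinePackedQuantizedBlockLoader, evaluated for
# BROWS = BCOLS = 32 and tgp_size = 2 * 2 * 32 (sketch only).
BROWS = BCOLS = 32
tgp_size = 2 * 2 * 32
group_size = 64
row_pack_factor = 4
for bits in (2, 4, 8):
    pack_factor = 32 // bits
    bcols_packed = BCOLS * row_pack_factor // pack_factor
    brows_packed = BROWS // row_pack_factor
    total_ints = bcols_packed * brows_packed
    n_reads = 1 if total_ints < tgp_size else total_ints // tgp_size
    group_steps = group_size // BCOLS
    print(bits, pack_factor, bcols_packed, brows_packed, n_reads, group_steps)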
@@ -104,8 +104,12 @@
   instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
   instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
 
 #define instantiate_quantized_all_affine_packed(type, group_size, bits) \
-  instantiate_quantized_affine_packed(affine_packed_qmv_fast, type, group_size, bits)
+  instantiate_quantized_affine_packed(affine_packed_qmv_fast, type, group_size, bits) \
+  instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, true) \
+  instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, false) \
+  instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, true) \
+  instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, false)
 
 #define instantiate_quantized_funcs(type, group_size, bits) \
   instantiate_quantized_all_single(type, group_size, bits) \
@@ -428,8 +428,91 @@ void affine_packed_qmv(
   compute_encoder.set_input_array(scales, 1);
   compute_encoder.set_input_array(x, 2);
   compute_encoder.set_output_array(out, 3);
-  compute_encoder.set_bytes(D, 5);
-  compute_encoder.set_bytes(O, 6);
+  compute_encoder.set_bytes(D, 4);
+  compute_encoder.set_bytes(O, 5);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+}
+
+void affine_packed_qmm_t(
+    const std::vector<array>& inputs,
+    array& out,
+    bool batched,
+    int B,
+    int D,
+    int O,
+    int group_size,
+    int bits,
+    const Stream& s) {
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  auto& d = metal::device(s.device);
+  auto ensure_row_contiguous_last_dims = [&d, &s](const array& arr) {
+    auto stride_0 = arr.strides()[arr.ndim() - 2];
+    auto stride_1 = arr.strides()[arr.ndim() - 1];
+    if (stride_0 == arr.shape(-1) && stride_1 == 1) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy_gpu(arr, arr_copy, CopyType::General, s);
+      d.add_temporary(arr_copy, s.index);
+      return arr_copy;
+    }
+  };
+  // TODO: Deal with this in routing towards qmm_n instead of qmm_t
+  auto x = ensure_row_contiguous_last_dims(inputs[0]);
+  auto w = ensure_row_contiguous_last_dims(inputs[1]);
+  auto scales = ensure_row_contiguous_last_dims(inputs[2]);
+
+  int x_batch_ndims = x.ndim() - 2;
+  auto& x_shape = x.shape();
+  auto& x_strides = x.strides();
+  int w_batch_ndims = w.ndim() - 2;
+  auto& w_shape = w.shape();
+  auto& w_strides = w.strides();
+  auto& s_strides = scales.strides();
+
+  const int wn = 2;
+  const int wm = 2;
+  const int bm = 32;
+  const int bn = 32;
+  const int N = (batched) ? out.size() / B / O : 1;
+  MTL::Size group_dims(32, wn, wm);
+  MTL::Size grid_dims((O + bn - 1) / bn, (B + bm - 1) / bm, N);
+
+  std::string name;
+  name.reserve(64);
+  concatenate(
+      name,
+      "affine_packed_qmm_t_",
+      get_type_string(out.dtype()),
+      "_gs_",
+      std::to_string(group_size),
+      "_b_",
+      std::to_string(bits),
+      "_alN_",
+      ((O % 32) == 0) ? "true" : "false",
+      "_batch_",
+      (batched) ? "true" : "false");
+  auto kernel = get_quantized_kernel(d, name, "");
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  compute_encoder.set_input_array(w, 0);
+  compute_encoder.set_input_array(scales, 1);
+  compute_encoder.set_input_array(x, 2);
+  compute_encoder.set_output_array(out, 3);
+  compute_encoder.set_bytes(D, 4);
+  compute_encoder.set_bytes(O, 5);
+  compute_encoder.set_bytes(B, 6);
+  if (batched) {
+    compute_encoder.set_bytes(x_batch_ndims, 7);
+    compute_encoder.set_vector_bytes(x_shape, 8);
+    compute_encoder.set_vector_bytes(x_strides, 9);
+    compute_encoder.set_bytes(w_batch_ndims, 10);
+    compute_encoder.set_vector_bytes(w_shape, 11);
+    compute_encoder.set_vector_bytes(w_strides, 12);
+    compute_encoder.set_vector_bytes(s_strides, 13);
+  }
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
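As a cross-check (not part of the diff), the kernel name and launch geometry built above can be mirrored in a few lines of Python. Here type_string stands in for get_type_string(out.dtype()) and the returned tuple is the threadgroup count passed to dispatch_threadgroups:

# Mirror of the host-side name construction and grid computation (sketch only).
def qmm_t_launch(B, O, group_size, bits, type_string, batched=False, n_batch=1):
    bm = bn = 32
    grid = ((O + bn - 1) // bn, (B + bm - 1) // bm, n_batch if batched else 1)
    name = (
        f"affine_packed_qmm_t_{type_string}_gs_{group_size}_b_{bits}"
        f"_alN_{'true' if O % 32 == 0 else 'false'}"
        f"_batch_{'true' if batched else 'false'}"
    )
    return name, grid

print(qmm_t_launch(B=1024, O=4096, group_size=64, bits=4, type_string="float16_t"))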
@@ -451,6 +534,7 @@ void affine_packed_qmm_op(
     if (B < 6) {
       affine_packed_qmv(inputs, out, B, D, O, group_size, bits, s);
     } else {
+      affine_packed_qmm_t(inputs, out, batched, B, D, O, group_size, bits, s);
     }
   } else {
   }
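The dispatch rule added here, restated (sketch, not part of the diff): batches of fewer than 6 rows keep using the fast QMV kernel, anything larger takes the new tiled QMM path.

# Routing sketch for the packed-affine matmul path.
def route_affine_packed(B):
    return "affine_packed_qmv" if B < 6 else "affine_packed_qmm_t"

assert route_affine_packed(1) == "affine_packed_qmv"
assert route_affine_packed(1024) == "affine_packed_qmm_t"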