Add packed_affine_qmm_t

Angelos Katharopoulos 2024-12-16 21:49:14 -08:00
parent 410ccdbed5
commit fb7be036af
4 changed files with 455 additions and 21 deletions

View File

@@ -1,19 +1,19 @@
import argparse
import math
from functools import partial
import mlx.core as mx
from time_utils import time_fn
B = 1024
D = 1024
M = 4 * D
group_size = 64
bits = 4
dtype = mx.float16
loops = 100
loops = 10
def qmv_(x, wq1, wq2, q_type):
def qmm_(x, wq1, wq2, q_type):
for i in range(loops):
x = mx.quantized_matmul(
x,
@@ -32,28 +32,28 @@ def qmv_(x, wq1, wq2, q_type):
return x
def affine_qmv(x, wq1, wq2):
return qmv_(x, wq1, wq2, "affine")
def affine_qmm(x, wq1, wq2):
return qmm_(x, wq1, wq2, "affine")
def affine_packed_qmv(x, wq1, wq2):
return qmv_(x, wq1, wq2, "affine-packed")
def affine_packed_qmm(x, wq1, wq2):
return qmm_(x, wq1, wq2, "affine-packed")
def time_qmv():
def time_qmm():
mx.random.seed(3)
x = mx.random.normal(shape=(1, D)).astype(dtype)
x = mx.random.normal(shape=(B, D)).astype(dtype)
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
mx.eval(x, wq1, wq2)
time_fn(affine_qmv, x, wq1, wq2)
time_fn(affine_qmm, x, wq1, wq2)
def time_packed_qmv():
def time_packed_qmm():
mx.random.seed(3)
x = mx.random.normal(shape=(1, D)).astype(dtype)
x = mx.random.normal(shape=(B, D)).astype(dtype)
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
wq1 = mx.quantize(
w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
@@ -63,12 +63,12 @@ def time_packed_qmv():
w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
)
mx.eval(x, wq1, wq2)
time_fn(affine_packed_qmv, x, wq1, wq2)
time_fn(affine_packed_qmm, x, wq1, wq2)
if __name__ == "__main__":
for b in [2, 4, 8]:
bits = b
print(f"Bits {bits}:")
time_qmv()
time_packed_qmv()
time_qmm()
time_packed_qmm()
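As a rough size check for the benchmark configuration (bits = 4), the sketch below works out how much data each matmul touches. The affine-packed layout details (rows packed 4 at a time, scales and biases interleaved per pack) are inferred from the kernel changes further down, not stated in this file, so treat them as assumptions.

# Size sketch for the benchmark constants (bits = 4). The affine-packed layout
# (4-row packs with interleaved scales/biases) is an assumption read off the
# kernel below, not something this benchmark spells out.
D, M, group_size, bits = 1024, 4 * 1024, 64, 4
pack_factor = 32 // bits                      # 8 weight values per uint32
weight_words = M * (D // pack_factor)         # 524288 uint32, 2 MiB of packed weights
groups_per_row = D // group_size              # 16 quantization groups per row

affine_scale_elems = M * groups_per_row * 2   # separate scales + biases: 131072 fp16 values
row_pack_factor = 4
packed_scale_elems = (M // row_pack_factor) * groups_per_row * 2 * row_pack_factor
assert packed_scale_elems == affine_scale_elems   # same data, different interleaving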

View File

@@ -1248,6 +1248,41 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device T*& scales,
device T*& y,
int output_stride,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx = tid.z;
uint32_t w_idx = tid.z;
if (x_batch_ndims == 1) {
x += x_idx * x_strides[0];
} else {
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
}
if (w_batch_ndims == 1) {
w += w_idx * w_strides[0];
scales += w_idx * s_strides[0];
} else {
ulong2 idx = elem_to_loc_broadcast(
w_idx, w_shape, w_strides, s_strides, w_batch_ndims);
w += idx.x;
scales += idx.y;
}
y += tid.z * output_stride;
}
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
@@ -2266,11 +2301,322 @@ template <typename T, int group_size, int bits>
const device vec<T, 4>* scales [[buffer(1)]],
const device T* x [[buffer(2)]],
device T* y [[buffer(3)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
affine_packed_qmv_fast_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short group_size,
short bits>
struct AffinePackedQuantizedBlockLoader {
static_assert(
BCOLS <= group_size,
"The group size should be larger than the columns");
static_assert(
group_size % BCOLS == 0,
"The group size should be divisible by the columns");
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
MLX_MTL_CONST short pack_factor = 32 / bits;
MLX_MTL_CONST short row_pack_factor = 4;
MLX_MTL_CONST short BCOLS_PACKED = BCOLS * row_pack_factor / pack_factor;
MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
MLX_MTL_CONST short TOTAL_INTS = BCOLS_PACKED * BROWS_PACKED;
MLX_MTL_CONST short n_reads =
(TOTAL_INTS < tgp_size) ? 1 : TOTAL_INTS / tgp_size;
MLX_MTL_CONST short group_steps = group_size / BCOLS;
static_assert(
n_reads <= row_pack_factor,
"The loader only supports per thread reads <= row_pack_factor");
const int src_ld;
const int tile_stride;
short group_step_cnt;
const int group_stride;
const short thread_idx;
const short bi;
const short bj;
const short bii;
const short bjj;
const device uint32_t* src;
const device T* scales;
const device T* biases;
threadgroup T* dst;
AffinePackedQuantizedBlockLoader(
const device uint32_t* src_,
const device T* scales_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS_PACKED : BROWS_PACKED * src_ld),
group_step_cnt(0),
group_stride(BROWS_PACKED * 2 * src_ld / group_size),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(n_reads * thread_idx / BCOLS_PACKED),
bj((n_reads * thread_idx) % BCOLS_PACKED),
bii(bi * row_pack_factor + bj % row_pack_factor),
bjj(bj / row_pack_factor),
src(src_ + bi * src_ld * row_pack_factor / pack_factor + bj),
scales(
scales_ + bi * 2 * src_ld * row_pack_factor / group_size +
bj % row_pack_factor),
biases(scales + row_pack_factor),
dst(dst_ + bii * dst_ld + bjj * pack_factor) {}
void load_unsafe() const {
if (bits == 2 && BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
return;
}
for (int i = 0; i < n_reads; i++) {
T scale = scales[i];
T bias = biases[i];
dequantize<T, pack_factor, bits>(
(const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld);
}
}
void load_safe(short2 src_tile_dim) const {
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
return;
}
if (reduction_dim == 1 && bii >= src_tile_dim.y) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
if (reduction_dim == 0 && bii >= src_tile_dim.x) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
for (int i = 0; i < n_reads; i++) {
T scale = scales[i];
T bias = biases[i];
dequantize<T, pack_factor, bits>(
(const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld);
}
}
void next() {
src += tile_stride;
if (reduction_dim == 1) {
if (group_steps > 1) {
group_step_cnt++;
if (group_step_cnt == group_steps) {
group_step_cnt = 0;
scales += 8;
biases += 8;
}
} else {
scales += 8;
biases += 8;
}
} else {
scales += group_stride;
biases += group_stride;
}
}
};
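To make the index arithmetic above concrete, here is a small Python model of the loader's thread-to-tile mapping for the instantiation affine_packed_qmm_t_impl uses below (BN = BK = 32, WM = WN = 2, float16, bits = 4, group_size = 64). The constants mirror the template expressions above; the loop and assert are illustrative, not part of the source.

# Thread-to-tile mapping of AffinePackedQuantizedBlockLoader for one instantiation.
bits, group_size = 4, 64
BROWS, BCOLS, tgp_size = 32, 32, 2 * 2 * 32        # BN, BK, WM * WN * SIMD_SIZE
pack_factor = 32 // bits                           # 8 values per uint32
row_pack_factor = 4
BCOLS_PACKED = BCOLS * row_pack_factor // pack_factor   # 16 uint32 per packed tile row
BROWS_PACKED = BROWS // row_pack_factor                 # 8 packed rows per tile
n_reads = max(1, BCOLS_PACKED * BROWS_PACKED // tgp_size)  # 1 uint32 (8 values) per thread

covered = set()
for thread_idx in range(tgp_size):
    bi = n_reads * thread_idx // BCOLS_PACKED
    bj = (n_reads * thread_idx) % BCOLS_PACKED
    bii = bi * row_pack_factor + bj % row_pack_factor   # tile row this thread writes
    bjj = bj // row_pack_factor                         # packed column within that row
    for v in range(pack_factor):                        # the 8 dequantized values
        covered.add((bii, bjj * pack_factor + v))
assert len(covered) == BROWS * BCOLS                    # the 32x32 tile is covered exactly once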
template <
typename T,
const int group_size,
const int bits,
const bool aligned_N,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void affine_packed_qmm_t_impl(
const device uint32_t* w,
const device T* scales,
const device T* x,
device T* y,
threadgroup T* Xs,
threadgroup T* Ws,
const constant int& K,
const constant int& N,
const constant int& M,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
(void)lid;
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int pack_factor = 32 / bits;
constexpr int row_pack_factor = 4;
constexpr int BK_padded = (BK + 16 / sizeof(T));
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::
BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
using loader_x_t =
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = AffinePackedQuantizedBlockLoader<
T,
BN,
BK,
BK_padded,
1,
WM * WN * SIMD_SIZE,
group_size,
bits>;
// Set the block
const int K_w = K * row_pack_factor / pack_factor;
const int K_g = K * 2 * row_pack_factor / group_size;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
const int packed_y_col = tid.x * (BN / row_pack_factor);
x += y_row * K;
w += packed_y_col * K_w;
scales += packed_y_col * K_g;
y += y_row * N + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
const short num_outs = min(BN, N - y_col);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
loader_w_t loader_w(w, scales, K, Ws, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
if (num_els < BM) {
if (!aligned_N && num_outs < BN) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_safe(short2(BK, num_outs));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
} else {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
}
} else {
if (!aligned_N && num_outs < BN) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_safe(short2(BK, num_outs));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
} else {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
}
}
// Store results to device memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (num_els < BM || num_outs < BN) {
mma_op.store_result_safe(y, N, short2(num_outs, num_els));
} else {
mma_op.store_result(y, N);
}
}
template <
typename T,
const int group_size,
const int bits,
const bool aligned_N,
const bool batched,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void affine_packed_qmm_t(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* x [[buffer(2)]],
device T* y [[buffer(3)]],
const constant int& K [[buffer(4)]],
const constant int& N [[buffer(5)]],
const constant int& M [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr int BK_padded = (BK + 16 / sizeof(T));
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
y,
M * N,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
tid);
}
affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
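Roughly, for the default tile sizes and the benchmark's float16 / K = 1024 case, the kernel's threadgroup footprint and main-loop trip count come out as in this sketch; sizeof(T) = 2 is an assumption tied to float16.

# Threadgroup-memory and loop-count arithmetic for BM = BK = BN = 32, K = 1024.
sizeof_T = 2
BM, BK, BN, K = 32, 32, 32, 1024
BK_padded = BK + 16 // sizeof_T          # 40: each tile row gets 16 bytes of padding
Xs_bytes = BM * BK_padded * sizeof_T     # 2560 bytes for the x tile
Ws_bytes = BN * BK_padded * sizeof_T     # 2560 bytes for the dequantized w tile
k_iterations = K // BK                   # 32 trips through the main loop
print(Xs_bytes + Ws_bytes, k_iterations) # 5120 bytes of threadgroup memory, 32 iterations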

View File

@@ -104,8 +104,12 @@
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_all_affine_packed(type, group_size, bits) \
instantiate_quantized_affine_packed(affine_packed_qmv_fast, type, group_size, bits)
#define instantiate_quantized_all_affine_packed(type, group_size, bits) \
instantiate_quantized_affine_packed(affine_packed_qmv_fast, type, group_size, bits) \
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, true) \
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, false) \
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, true) \
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, false)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \

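The host code in the next file looks these kernels up by a concatenated name, so each instantiate_quantized_aligned_batched line above should correspond to one name of the following form. qmm_t_kernel_name is a hypothetical helper that just mirrors the string built in affine_packed_qmm_t on the host side; the actual macro expansion is not shown in this diff.

# Sketch of the kernel-name scheme assumed by the host-side lookup below.
def qmm_t_kernel_name(type_str, group_size, bits, aligned_N, batched):
    return (
        f"affine_packed_qmm_t_{type_str}_gs_{group_size}_b_{bits}"
        f"_alN_{'true' if aligned_N else 'false'}"
        f"_batch_{'true' if batched else 'false'}"
    )

print(qmm_t_kernel_name("float16", 64, 4, True, False))
# affine_packed_qmm_t_float16_gs_64_b_4_alN_true_batch_false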
View File

@@ -428,8 +428,91 @@ void affine_packed_qmv(
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(x, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(D, 5);
compute_encoder.set_bytes(O, 6);
compute_encoder.set_bytes(D, 4);
compute_encoder.set_bytes(O, 5);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void affine_packed_qmm_t(
const std::vector<array>& inputs,
array& out,
bool batched,
int B,
int D,
int O,
int group_size,
int bits,
const Stream& s) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& d = metal::device(s.device);
auto ensure_row_contiguous_last_dims = [&d, &s](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
d.add_temporary(arr_copy, s.index);
return arr_copy;
}
};
// TODO: Deal with this in routing towards qmm_n instead of qmm_t
auto x = ensure_row_contiguous_last_dims(inputs[0]);
auto w = ensure_row_contiguous_last_dims(inputs[1]);
auto scales = ensure_row_contiguous_last_dims(inputs[2]);
int x_batch_ndims = x.ndim() - 2;
auto& x_shape = x.shape();
auto& x_strides = x.strides();
int w_batch_ndims = w.ndim() - 2;
auto& w_shape = w.shape();
auto& w_strides = w.strides();
auto& s_strides = scales.strides();
const int wn = 2;
const int wm = 2;
const int bm = 32;
const int bn = 32;
const int N = (batched) ? out.size() / B / O : 1;
MTL::Size group_dims(32, wn, wm);
MTL::Size grid_dims((O + bn - 1) / bn, (B + bm - 1) / bm, N);
std::string name;
name.reserve(64);
concatenate(
name,
"affine_packed_qmm_t_",
get_type_string(out.dtype()),
"_gs_",
std::to_string(group_size),
"_b_",
std::to_string(bits),
"_alN_",
((O % 32) == 0) ? "true" : "false",
"_batch_",
(batched) ? "true" : "false");
auto kernel = get_quantized_kernel(d, name, "");
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(x, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(D, 4);
compute_encoder.set_bytes(O, 5);
compute_encoder.set_bytes(B, 6);
if (batched) {
compute_encoder.set_bytes(x_batch_ndims, 7);
compute_encoder.set_vector_bytes(x_shape, 8);
compute_encoder.set_vector_bytes(x_strides, 9);
compute_encoder.set_bytes(w_batch_ndims, 10);
compute_encoder.set_vector_bytes(w_shape, 11);
compute_encoder.set_vector_bytes(w_strides, 12);
compute_encoder.set_vector_bytes(s_strides, 13);
}
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
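For the benchmark's qmm case (B = 1024, O = 4 * D = 4096, unbatched), the grid and group sizes computed above work out as in this sketch, which only mirrors the expressions in affine_packed_qmm_t.

# Dispatch geometry for B = 1024, O = 4096, unbatched (arithmetic only).
B, O, bm, bn, wm, wn = 1024, 4096, 32, 32, 2, 2
group_dims = (32, wn, wm)                                  # 128 threads per threadgroup
grid_dims = ((O + bn - 1) // bn, (B + bm - 1) // bm, 1)    # (128, 32, 1)
threadgroups = grid_dims[0] * grid_dims[1] * grid_dims[2]  # 4096 tiles of 32x32 outputs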
@@ -451,6 +534,7 @@ void affine_packed_qmm_op(
if (B < 6) {
affine_packed_qmv(inputs, out, B, D, O, group_size, bits, s);
} else {
affine_packed_qmm_t(inputs, out, batched, B, D, O, group_size, bits, s);
}
} else {
}
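Putting it together, affine_packed_qmm_op routes small batches to the existing qmv kernel and everything else to the new tiled kernel. packed_route below is a hypothetical helper restating the B < 6 check from the hunk above, with the benchmark shapes as examples.

# Routing sketch based on the B < 6 check above.
def packed_route(B):
    return "affine_packed_qmv" if B < 6 else "affine_packed_qmm_t"

print(packed_route(1))     # the old qmv benchmark used x of shape (1, D)
print(packed_route(1024))  # the updated benchmark uses x of shape (B, D) with B = 1024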