Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)
Batched Quantized Matmul + Fast Small QMV (#1503)
* add fast qmv for small dims
* fix test
* batched cpu
* add batched template param
* refactor metal quantized.cpp

Parent: 58a855682c
Commit: d15fa13daf
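For orientation before the diff: the constants in the CPU kernel below (bits = 4, group_size = 64, pack_factor = 32 / bits) imply the affine scheme value = scale * code + bias, with eight 4-bit codes packed into each uint32 and one scale/bias pair shared per group of 64 values. A minimal sketch of unpacking one packed word under those assumptions (the helper name is ours, not from the commit):

#include <cstdint>

// Hypothetical helper: expand one uint32 holding pack_factor = 8
// 4-bit codes into floats via value = scale * code + bias.
inline void dequantize_word_4bit(
    uint32_t word, float scale, float bias, float out[8]) {
  constexpr int bits = 4;
  constexpr int bitmask = (1 << bits) - 1; // 0xF
  constexpr int pack_factor = 32 / bits;   // 8 values per word
  for (int j = 0; j < pack_factor; ++j) {
    out[j] = scale * float((word >> (bits * j)) & bitmask) + bias;
  }
}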
@@ -18,13 +18,19 @@ void _qmm_t_4_64(
    const float* biases,
    int M,
    int N,
    int K) {
    int K,
    int B,
    bool batched_w) {
  constexpr int bits = 4;
  constexpr int group_size = 64;
  constexpr int bitmask = (1 << bits) - 1;
  constexpr int pack_factor = 32 / bits;
  constexpr int packs_in_group = group_size / pack_factor;

  int w_els = N * K / pack_factor;
  int g_els = w_els * pack_factor / group_size;

  for (int i = 0; i < B; i++) {
    for (int m = 0; m < M; m++) {
      const uint32_t* w_local = w;
      const float* scales_local = scales;
@@ -62,6 +68,12 @@ void _qmm_t_4_64(

      x += K;
    }
    if (batched_w) {
      w += w_els;
      scales += g_els;
      biases += g_els;
    }
  }
}

} // namespace
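Note the shape of the batching above: x and the output always advance by one matrix per outer iteration, while w, scales, and biases advance (by w_els and g_els elements) only when the weights are themselves batched; otherwise the same weight matrix is broadcast across the batch. A standalone sketch of that pattern, with hypothetical names:

// Hypothetical illustration of the per-batch pointer advance.
for (int i = 0; i < B; i++) {
  single_qmm(out, x, w, scales, biases, M, N, K);
  out += M * N;       // output always advances
  x += M * K;         // activations always advance
  if (batched_w) {    // weights advance only when batched
    w += w_els;       // N * K / pack_factor packed words
    scales += g_els;
    biases += g_els;
  }
}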
@@ -82,8 +94,10 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  if (condition) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    int K = x.shape(-1);
    int M = x.size() / K;
    int M = x.shape(-2);
    int N = out.shape(-1);
    int B = x.size() / K / M;
    bool batched_w = w.ndim() > 2;
    _qmm_t_4_64(
        out.data<float>(),
        x.data<float>(),
@@ -92,7 +106,9 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
        biases.data<float>(),
        M,
        N,
        K);
        K,
        B,
        batched_w);
  } else {
    eval(inputs, out);
  }
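The eval_cpu change above is the matching shape bookkeeping: the old code folded every leading axis into M, while the batched version keeps M as the per-matrix row count and derives the batch size from what remains. With x of shape [..., M, K]:

int K = x.shape(-1);      // columns of x
int M = x.shape(-2);      // rows per matrix (was x.size() / K)
int B = x.size() / K / M; // product of all leading batch axes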
@@ -201,17 +201,22 @@ void _qmm_dispatch(
    int group_size,
    bool transposed_w) {
  int K = x.shape(-1);
  int M = x.size() / K;
  int M = x.shape(-2);
  int N = out.shape(-1);

  int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
  int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;

  int batch_size = x.size() / x.shape(-1) / x.shape(-2);
  for (int i = 0; i < batch_size; i++) {
    switch (x.dtype()) {
      case float32:
        _qmm_dispatch_typed<float>(
            out.data<float>(),
            x.data<float>(),
            w.data<uint32_t>(),
            scales.data<float>(),
            biases.data<float>(),
            out.data<float>() + i * M * N,
            x.data<float>() + elem_to_loc(i * M * K, x),
            w.data<uint32_t>() + elem_to_loc(i * w_els, w),
            scales.data<float>() + elem_to_loc(i * g_els, scales),
            biases.data<float>() + elem_to_loc(i * g_els, biases),
            M,
            N,
            K,
@@ -221,11 +226,11 @@ void _qmm_dispatch(
        break;
      case float16:
        _qmm_dispatch_typed<float16_t>(
            out.data<float16_t>(),
            x.data<float16_t>(),
            w.data<uint32_t>(),
            scales.data<float16_t>(),
            biases.data<float16_t>(),
            out.data<float16_t>() + i * M * N,
            x.data<float16_t>() + elem_to_loc(i * M * K, x),
            w.data<uint32_t>() + elem_to_loc(i * w_els, w),
            scales.data<float16_t>() + elem_to_loc(i * g_els, scales),
            biases.data<float16_t>() + elem_to_loc(i * g_els, biases),
            M,
            N,
            K,
@@ -235,11 +240,11 @@ void _qmm_dispatch(
        break;
      case bfloat16:
        _qmm_dispatch_typed<bfloat16_t>(
            out.data<bfloat16_t>(),
            x.data<bfloat16_t>(),
            w.data<uint32_t>(),
            scales.data<bfloat16_t>(),
            biases.data<bfloat16_t>(),
            out.data<bfloat16_t>() + i * M * N,
            x.data<bfloat16_t>() + elem_to_loc(i * M * K, x),
            w.data<uint32_t>() + elem_to_loc(i * w_els, w),
            scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales),
            biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases),
            M,
            N,
            K,
@@ -252,6 +257,7 @@ void _qmm_dispatch(
          "[quantized_matmul] only floating types are supported");
    }
  }
}

void _bs_qmm_dispatch(
    array& out,
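The elem_to_loc calls above are what make the batched loop stride-aware: they map a flat row-major element index to a memory offset under the array's actual strides, so broadcast weights (stride 0 along the batch axis) and non-contiguous batches both index correctly. A rough free-standing sketch of that mapping (MLX's real helper differs in signature):

#include <cstddef>
#include <vector>

// Sketch: convert flat row-major index `idx` over `shape` into an
// offset under `strides` (zero strides implement broadcasting).
size_t elem_to_loc_sketch(
    size_t idx,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (idx % shape[i]) * strides[i];
    idx /= shape[i];
  }
  return loc;
}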
@@ -8,6 +8,7 @@ using namespace metal;
#define MLX_MTL_CONST static constant constexpr const

MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;

template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
@@ -371,6 +372,64 @@ struct QuantizedBlockLoader {
  }
};

template <typename T, int group_size, int bits, int D>
METAL_FUNC void qmv_quad_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = D / QUAD_SIZE;
  constexpr int packs_per_thread = values_per_thread / pack_factor;
  constexpr int scale_step_per_thread = group_size / values_per_thread;
  constexpr int results_per_quadgroup = 8;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_quadgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;

  w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
  scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  x += tid.y * in_vec_size + quad_lid * values_per_thread;
  y += tid.y * out_vec_size + out_row;

  U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

  for (int row = 0; row < results_per_quadgroup; row++) {
    const device uint8_t* wl =
        (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
    const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
    const device T* bl = biases + row * in_vec_size_g * quads_per_simd;

    U s = sl[0];
    U b = bl[0];
    if (row * quads_per_simd + out_row < out_vec_size) {
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }
  }

  for (int row = 0; row < results_per_quadgroup; row++) {
    result[row] = quad_sum(result[row]);
    if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
      y[row * quads_per_simd] = static_cast<T>(result[row]);
    }
  }
}
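Why quadgroups help at D = 64 or 128: with values_per_thread = D / 4, a single 4-lane quadgroup covers an entire row, so the reduction is a quad_sum instead of a 32-lane simd_sum, and each 32-lane SIMD group still produces 64 output rows. Checking the constants (a standalone sketch; assumes the 32-lane SIMD width defined above):

// Standalone arithmetic check for D = 64, bits = 4.
constexpr int SIMD_SIZE = 32, QUAD_SIZE = 4, D = 64, bits = 4;
constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;              // 8
constexpr int values_per_thread = D / QUAD_SIZE;                   // 16
constexpr int packs_per_thread = values_per_thread / (32 / bits);  // 2 words
constexpr int results_per_quadgroup = 8;
// Rows per threadgroup: 8 quadgroups x 8 results each = 64.
static_assert(quads_per_simd * results_per_quadgroup == 64, "64 rows");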

template <typename T, int group_size, int bits>
METAL_FUNC void qmv_fast_impl(
    const device uint32_t* w,
@@ -586,10 +645,10 @@ METAL_FUNC void qmv_impl(

template <typename T, const int group_size, const int bits>
METAL_FUNC void qvm_impl(
    const device T* x,
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
@@ -697,16 +756,16 @@ template <
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_t_impl(
    const device T* x,
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& M,
    const constant int& N,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -818,16 +877,16 @@ template <
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_n_impl(
    const device T* x,
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& M,
    const constant int& N,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -942,6 +1001,45 @@ METAL_FUNC void qmm_n_impl(
  }
}

template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device T*& scales,
    const device T*& biases,
    device T*& y,
    int output_stride,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant size_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant size_t* w_strides,
    const constant size_t* s_strides,
    const constant size_t* b_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx = tid.z;
  uint32_t w_idx = tid.z;
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
    biases += w_idx * b_strides[0];
  } else {
    ulong3 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
    biases += idx.z;
  }
  y += tid.z * output_stride;
}

template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
@@ -996,7 +1094,58 @@ METAL_FUNC void adjust_matrix_offsets(
  y += tid.z * output_stride;
}
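This offset helper is the whole batching mechanism for the GPU path: the dispatch grid gains a z extent equal to the flattened batch, and adjust_matrix_offsets rebases x, w, scales, biases, and y to the matrix selected by tid.z before the unchanged single-matrix implementation runs. Host-side that corresponds to a dispatch of the form (sketch; mirrors the grid_dims computations in the eval_gpu code later in this commit):

// Sketch: the batch count N becomes the grid's z extent.
MTL::Size group_dims = MTL::Size(32, 2, 1);
MTL::Size grid_dims = MTL::Size(O / 8, B, N); // tid.z walks the batch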

template <typename T, int group_size, int bits>
template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void qmv_quad(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  if (batched) {
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        out_vec_size,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qmv_quad_impl<T, group_size, bits, D>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
      quad_gid,
      quad_lid);
}
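One detail worth noting in these kernel wrappers: batched is a template parameter, not a runtime argument, so the if (batched) guard is resolved when the pipeline is compiled and the _batch_0 specializations contain no offset-adjustment code at all. The pattern, reduced to its skeleton:

// Skeleton of the compile-time guard used by each wrapper.
template <typename T, int group_size, int bits, bool batched>
void kernel_wrapper(/* buffers... */) {
  if (batched) {               // constant-folded per specialization
    // adjust_matrix_offsets<T>(...);
  }
  // single-matrix impl runs on the (possibly rebased) pointers
}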
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
@@ -1005,9 +1154,35 @@ template <typename T, int group_size, int bits>
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        out_vec_size,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qmv_fast_impl<T, group_size, bits>(
      w,
      scales,
@@ -1021,7 +1196,7 @@ template <typename T, int group_size, int bits>
      simd_lid);
}
template <typename T, const int group_size, const int bits>
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
@@ -1030,9 +1205,35 @@ template <typename T, const int group_size, const int bits>
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        out_vec_size,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qmv_impl<T, group_size, bits>(
      w,
      scales,
@@ -1046,24 +1247,50 @@ template <typename T, const int group_size, const int bits>
      simd_lid);
}
template <typename T, const int group_size, const int bits>
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qvm(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  qvm_impl<T, group_size, bits>(
  if (batched) {
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        out_vec_size,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qvm_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
@@ -1076,18 +1303,27 @@ template <
    const int group_size,
    const int bits,
    const bool aligned_N,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_t(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& M [[buffer(5)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& K [[buffer(7)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1099,26 +1335,53 @@ template <
  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        M * N,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_n(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& M [[buffer(5)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& K [[buffer(7)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1131,8 +1394,27 @@ template <
  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        M * N,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }

  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, int group_size, int bits>
@@ -1141,23 +1423,23 @@ template <typename T, int group_size, int bits>
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& in_vec_size [[buffer(7)]],
    const constant int& out_vec_size [[buffer(8)]],
    const constant int& batch_ndims [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* lhs_strides [[buffer(11)]],
    const constant size_t* rhs_strides [[buffer(12)]],
    const constant int& x_batch_ndims [[buffer(13)]],
    const constant int* x_shape [[buffer(14)]],
    const constant size_t* x_strides [[buffer(15)]],
    const constant int& w_batch_ndims [[buffer(16)]],
    const constant int* w_shape [[buffer(17)]],
    const constant size_t* w_strides [[buffer(18)]],
    const constant size_t* s_strides [[buffer(19)]],
    const constant size_t* b_strides [[buffer(20)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant size_t* lhs_strides [[buffer(19)]],
    const constant size_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1202,23 +1484,23 @@ template <typename T, int group_size, int bits>
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& in_vec_size [[buffer(7)]],
    const constant int& out_vec_size [[buffer(8)]],
    const constant int& batch_ndims [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* lhs_strides [[buffer(11)]],
    const constant size_t* rhs_strides [[buffer(12)]],
    const constant int& x_batch_ndims [[buffer(13)]],
    const constant int* x_shape [[buffer(14)]],
    const constant size_t* x_strides [[buffer(15)]],
    const constant int& w_batch_ndims [[buffer(16)]],
    const constant int* w_shape [[buffer(17)]],
    const constant size_t* w_strides [[buffer(18)]],
    const constant size_t* s_strides [[buffer(19)]],
    const constant size_t* b_strides [[buffer(20)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant size_t* lhs_strides [[buffer(19)]],
    const constant size_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1259,27 +1541,27 @@ template <typename T, int group_size, int bits>

template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& in_vec_size [[buffer(7)]],
    const constant int& out_vec_size [[buffer(8)]],
    const constant int& batch_ndims [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* lhs_strides [[buffer(11)]],
    const constant size_t* rhs_strides [[buffer(12)]],
    const constant int& x_batch_ndims [[buffer(13)]],
    const constant int* x_shape [[buffer(14)]],
    const constant size_t* x_strides [[buffer(15)]],
    const constant int& w_batch_ndims [[buffer(16)]],
    const constant int* w_shape [[buffer(17)]],
    const constant size_t* w_strides [[buffer(18)]],
    const constant size_t* s_strides [[buffer(19)]],
    const constant size_t* b_strides [[buffer(20)]],
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant size_t* lhs_strides [[buffer(19)]],
    const constant size_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1306,10 +1588,10 @@ template <typename T, int group_size, int bits>
      b_strides,
      tid);
  qvm_impl<T, group_size, bits>(
      x,
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
@@ -1327,28 +1609,28 @@ template <
    const int BK = 32,
    const int BN = 32>
[[kernel]] void bs_qmm_t(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& N [[buffer(8)]],
    const constant int& K [[buffer(9)]],
    const constant int& batch_ndims [[buffer(10)]],
    const constant int* batch_shape [[buffer(11)]],
    const constant size_t* lhs_strides [[buffer(12)]],
    const constant size_t* rhs_strides [[buffer(13)]],
    const constant int& x_batch_ndims [[buffer(14)]],
    const constant int* x_shape [[buffer(15)]],
    const constant size_t* x_strides [[buffer(16)]],
    const constant int& w_batch_ndims [[buffer(17)]],
    const constant int* w_shape [[buffer(18)]],
    const constant size_t* w_strides [[buffer(19)]],
    const constant size_t* s_strides [[buffer(20)]],
    const constant size_t* b_strides [[buffer(21)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    const constant int& batch_ndims [[buffer(16)]],
    const constant int* batch_shape [[buffer(17)]],
    const device uint32_t* lhs_indices [[buffer(18)]],
    const device uint32_t* rhs_indices [[buffer(19)]],
    const constant size_t* lhs_strides [[buffer(20)]],
    const constant size_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1383,7 +1665,7 @@ template <
      b_strides,
      tid);
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

template <
@@ -1394,28 +1676,28 @@ template <
    const int BK = 32,
    const int BN = 32>
[[kernel]] void bs_qmm_n(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& N [[buffer(8)]],
    const constant int& K [[buffer(9)]],
    const constant int& batch_ndims [[buffer(10)]],
    const constant int* batch_shape [[buffer(11)]],
    const constant size_t* lhs_strides [[buffer(12)]],
    const constant size_t* rhs_strides [[buffer(13)]],
    const constant int& x_batch_ndims [[buffer(14)]],
    const constant int* x_shape [[buffer(15)]],
    const constant size_t* x_strides [[buffer(16)]],
    const constant int& w_batch_ndims [[buffer(17)]],
    const constant int* w_shape [[buffer(18)]],
    const constant size_t* w_strides [[buffer(19)]],
    const constant size_t* s_strides [[buffer(20)]],
    const constant size_t* b_strides [[buffer(21)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    const constant int& batch_ndims [[buffer(16)]],
    const constant int* batch_shape [[buffer(17)]],
    const device uint32_t* lhs_indices [[buffer(18)]],
    const device uint32_t* rhs_indices [[buffer(19)]],
    const constant size_t* lhs_strides [[buffer(20)]],
    const constant size_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1451,7 +1733,7 @@ template <
      b_strides,
      tid);
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

template <typename T, const int group_size, const int bits>
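Across the gather (bs_) kernels above, the buffers were also renumbered so that the first fifteen bindings of the vector kernels match the non-gather batched kernels exactly, with the index arrays and their strides moved to the tail (the matrix kernels put K/N/M at buffers 5-7 and shift the batch metadata down by one). The resulting vector-kernel layout, written out as a hypothetical enum for reference:

// Hypothetical summary of the unified buffer layout (not in the source).
enum QuantizedKernelBuffers {
  kW = 0, kScales = 1, kBiases = 2, kX = 3, kY = 4,
  kInVecSize = 5, kOutVecSize = 6,               // shared by all kernels
  kXBatchNdims = 7, kXShape = 8, kXStrides = 9,
  kWBatchNdims = 10, kWShape = 11, kWStrides = 12,
  kSStrides = 13, kBStrides = 14,                // batched + gather
  kBatchNdims = 15, kBatchShape = 16,
  kLhsIndices = 17, kRhsIndices = 18,
  kLhsStrides = 19, kRhsStrides = 20             // gather only
};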
@@ -13,32 +13,14 @@
        group_size, \
        bits)

#define instantiate_quantized_types(name, group_size, bits) \
  instantiate_quantized(name, float, group_size, bits) \
  instantiate_quantized(name, float16_t, group_size, bits) \
  instantiate_quantized(name, bfloat16_t, group_size, bits)

#define instantiate_quantized_groups(name, bits) \
  instantiate_quantized_types(name, 128, bits) \
  instantiate_quantized_types(name, 64, bits) \
  instantiate_quantized_types(name, 32, bits)

#define instantiate_quantized_all(name) \
  instantiate_quantized_groups(name, 2) \
  instantiate_quantized_groups(name, 4) \
  instantiate_quantized_groups(name, 8)

instantiate_quantized_all(qmv_fast)
instantiate_quantized_all(qmv)
instantiate_quantized_all(qvm)
instantiate_quantized_all(qmm_n)
instantiate_quantized_all(bs_qmv_fast)
instantiate_quantized_all(bs_qmv)
instantiate_quantized_all(bs_qvm)
instantiate_quantized_all(bs_qmm_n)
instantiate_quantized_all(affine_quantize)
instantiate_quantized_all(affine_quantize_scales_biases)
instantiate_quantized_all(affine_dequantize)
#define instantiate_quantized_batched(name, type, group_size, bits, batched) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits "_batch_" #batched, \
      name, \
      type, \
      group_size, \
      bits, \
      batched)

#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
  instantiate_kernel( \
@@ -49,23 +31,78 @@ instantiate_quantized_all(affine_dequantize)
      bits, \
      aligned)

#define instantiate_quantized_types_aligned(name, group_size, bits) \
  instantiate_quantized_aligned(name, float, group_size, bits, true) \
  instantiate_quantized_aligned(name, float16_t, group_size, bits, true) \
  instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, true) \
  instantiate_quantized_aligned(name, float, group_size, bits, false) \
  instantiate_quantized_aligned(name, float16_t, group_size, bits, false) \
  instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, false)
#define instantiate_quantized_aligned_batched(name, type, group_size, bits, aligned, batched) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned "_batch_" #batched, \
      name, \
      type, \
      group_size, \
      bits, \
      aligned, \
      batched)

#define instantiate_quantized_groups_aligned(name, bits) \
  instantiate_quantized_types_aligned(name, 128, bits) \
  instantiate_quantized_types_aligned(name, 64, bits) \
  instantiate_quantized_types_aligned(name, 32, bits)
#define instantiate_quantized_quad(name, type, group_size, bits, D, batched) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits "_d_" #D "_batch_" #batched, \
      name, \
      type, \
      group_size, \
      bits, \
      D, \
      batched)

#define instantiate_quantized_all_aligned(name) \
  instantiate_quantized_groups_aligned(name, 2) \
  instantiate_quantized_groups_aligned(name, 4) \
  instantiate_quantized_groups_aligned(name, 8) \
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
  instantiate_quantized_batched(name, type, group_size, bits, 1) \
  instantiate_quantized_batched(name, type, group_size, bits, 0)

instantiate_quantized_all_aligned(qmm_t)
instantiate_quantized_all_aligned(bs_qmm_t) // clang-format on
#define instantiate_quantized_all_batched(type, group_size, bits) \
  instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \
  instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \
  instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \
  instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)

#define instantiate_quantized_all_single(type, group_size, bits) \
  instantiate_quantized(affine_quantize, type, group_size, bits) \
  instantiate_quantized(affine_quantize_scales_biases, type, group_size, bits) \
  instantiate_quantized(affine_dequantize, type, group_size, bits) \
  instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
  instantiate_quantized(bs_qmv, type, group_size, bits) \
  instantiate_quantized(bs_qvm, type, group_size, bits) \
  instantiate_quantized(bs_qmm_n, type, group_size, bits)

#define instantiate_quantized_all_aligned(type, group_size, bits) \
  instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \
  instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \
  instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
  instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
  instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
  instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 0)

#define instantiate_quantized_all_quad(type, group_size, bits) \
  instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \
  instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \
  instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
  instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)

#define instantiate_quantized_funcs(type, group_size, bits) \
  instantiate_quantized_all_single(type, group_size, bits) \
  instantiate_quantized_all_batched(type, group_size, bits) \
  instantiate_quantized_all_aligned(type, group_size, bits) \
  instantiate_quantized_all_quad(type, group_size, bits)

#define instantiate_quantized_types(group_size, bits) \
  instantiate_quantized_funcs(float, group_size, bits) \
  instantiate_quantized_funcs(float16_t, group_size, bits) \
  instantiate_quantized_funcs(bfloat16_t, group_size, bits)

#define instantiate_quantized_groups(bits) \
  instantiate_quantized_types(128, bits) \
  instantiate_quantized_types(64, bits) \
  instantiate_quantized_types(32, bits)

#define instantiate_quantized_all() \
  instantiate_quantized_groups(2) \
  instantiate_quantized_groups(4) \
  instantiate_quantized_groups(8)

instantiate_quantized_all() // clang-format on
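Each instantiate_kernel expansion above token-pastes a unique specialization name; for example qmv_fast with float16_t, group size 64, 4 bits and batching enabled becomes qmv_fast_float16_t_gs_64_b_4_batch_1. The host code below rebuilds exactly these strings at dispatch time; a condensed sketch of that naming convention:

#include <sstream>
#include <string>

// Sketch: rebuild an instantiated kernel name host-side, mirroring the
// "_gs_" / "_b_" / "_batch_" suffixes produced by the macros above.
std::string quantized_kernel_name(
    const std::string& base, // e.g. "qmv_fast"
    const std::string& type, // e.g. "float16_t"
    int group_size,
    int bits,
    bool batched) {
  std::ostringstream kname;
  kname << base << "_" << type << "_gs_" << group_size << "_b_" << bits
        << "_batch_" << (batched ? 1 : 0);
  return kname.str(); // "qmv_fast_float16_t_gs_64_b_4_batch_1"
}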
@@ -12,231 +12,29 @@

namespace mlx::core {

void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 4);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  auto& s = stream();
  auto& d = metal::device(s.device);

void launch_qmm(
    std::string name,
    const std::vector<array>& inputs,
    array& out,
    int group_size,
    int bits,
    int D,
    int O,
    int B,
    int N,
    MTL::Size& group_dims,
    MTL::Size& grid_dims,
    bool batched,
    bool matrix,
    bool gather,
    bool aligned,
    bool quad,
    const Stream& s) {
  auto& x_pre = inputs[0];
  auto& w_pre = inputs[1];
  auto& scales_pre = inputs[2];
  auto& biases_pre = inputs[3];

  std::vector<array> copies;
  auto ensure_row_contiguous = [&copies, &s](const array& arr) {
    if (arr.flags().row_contiguous) {
      return arr;
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy_gpu(arr, arr_copy, CopyType::General, s);
      copies.push_back(arr_copy);
      return arr_copy;
    }
  };
  auto x = ensure_row_contiguous(x_pre);
  auto w = ensure_row_contiguous(w_pre);
  auto scales = ensure_row_contiguous(scales_pre);
  auto biases = ensure_row_contiguous(biases_pre);

  int D = x.shape(-1);
  int B = x.size() / D;
  int O = out.shape(-1);
  if (transpose_) {
    // Route to the fast qmv kernel that has no bounds checking
    if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
      std::ostringstream kname;
      auto type_string = get_type_string(x.dtype());
      kname << "qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
            << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "qmv_fast", type_string, group_size_, bits_);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      int bo = 8;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, 2, 1);
      MTL::Size grid_dims = MTL::Size(O / bo, B, 1);

      compute_encoder.set_input_array(w, 0);
      compute_encoder.set_input_array(scales, 1);
      compute_encoder.set_input_array(biases, 2);
      compute_encoder.set_input_array(x, 3);
      compute_encoder.set_output_array(out, 4);
      compute_encoder->setBytes(&D, sizeof(int), 5);
      compute_encoder->setBytes(&O, sizeof(int), 6);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }

    // Route to the qmv kernel
    else if (B < 6) {
      std::ostringstream kname;
      auto type_string = get_type_string(x.dtype());
      kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "qmv", type_string, group_size_, bits_);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      int bo = 8;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, 2, 1);
      MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);

      compute_encoder.set_input_array(w, 0);
      compute_encoder.set_input_array(scales, 1);
      compute_encoder.set_input_array(biases, 2);
      compute_encoder.set_input_array(x, 3);
      compute_encoder.set_output_array(out, 4);
      compute_encoder->setBytes(&D, sizeof(int), 5);
      compute_encoder->setBytes(&O, sizeof(int), 6);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }

    // Route to the qmm_t kernel
    else {
      std::ostringstream kname;
      std::string aligned_n = (O % 32) == 0 ? "true" : "false";
      auto type_string = get_type_string(x.dtype());
      kname << "qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
            << bits_ << "_alN_" << aligned_n;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "qmm_t", type_string, group_size_, bits_, aligned_n);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      int wn = 2;
      int wm = 2;
      int bm = 32;
      int bn = 32;
      int bk = 32;
      MTL::Size group_dims = MTL::Size(32, wn, wm);
      MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);

      compute_encoder.set_input_array(x, 0);
      compute_encoder.set_input_array(w, 1);
      compute_encoder.set_input_array(scales, 2);
      compute_encoder.set_input_array(biases, 3);
      compute_encoder.set_output_array(out, 4);
      compute_encoder->setBytes(&B, sizeof(int), 5);
      compute_encoder->setBytes(&O, sizeof(int), 6);
      compute_encoder->setBytes(&D, sizeof(int), 7);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }
  } else {
    // Route to the qvm kernel
    if (B < 4) {
      std::ostringstream kname;
      auto type_string = get_type_string(x.dtype());
      kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "qvm", type_string, group_size_, bits_);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      int bo = 64;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, 2, 1);
      MTL::Size grid_dims = MTL::Size(O / bo, B, 1);

      compute_encoder.set_input_array(x, 0);
      compute_encoder.set_input_array(w, 1);
      compute_encoder.set_input_array(scales, 2);
      compute_encoder.set_input_array(biases, 3);
      compute_encoder.set_output_array(out, 4);
      compute_encoder->setBytes(&D, sizeof(int), 5);
      compute_encoder->setBytes(&O, sizeof(int), 6);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }

    // Route to the qmm_n kernel
    else {
      std::ostringstream kname;
      auto type_string = get_type_string(x.dtype());
      kname << "qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
            << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "qmm_n", type_string, group_size_, bits_);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      int wn = 2;
      int wm = 2;
      int bm = 32;
      int bn = 32;
      int bk = 32;
      MTL::Size group_dims = MTL::Size(32, wn, wm);
      MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);

      if ((O % bn) != 0) {
        std::ostringstream msg;
        msg << "[quantized_matmul] The output size should be divisible by "
            << bn << " but received " << O << ".";
        throw std::runtime_error(msg.str());
      }

      compute_encoder.set_input_array(x, 0);
      compute_encoder.set_input_array(w, 1);
      compute_encoder.set_input_array(scales, 2);
      compute_encoder.set_input_array(biases, 3);
      compute_encoder.set_output_array(out, 4);
      compute_encoder->setBytes(&B, sizeof(int), 5);
      compute_encoder->setBytes(&O, sizeof(int), 6);
      compute_encoder->setBytes(&D, sizeof(int), 7);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }
  }

  if (!copies.empty()) {
    d.get_command_buffer(s.index)->addCompletedHandler(
        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
          copies.clear();
        });
  }
}

void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 6);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  auto& s = stream();
  auto& d = metal::device(s.device);

  auto& x_pre = inputs[0];
  auto& w_pre = inputs[1];
  auto& scales_pre = inputs[2];
  auto& biases_pre = inputs[3];
  auto& lhs_indices = inputs[4];
  auto& rhs_indices = inputs[5];

  // TODO: collapse batch dims
  auto& batch_shape = lhs_indices.shape();
  int batch_ndims = batch_shape.size();
  auto& lhs_strides = lhs_indices.strides();
  auto& rhs_strides = rhs_indices.strides();

  // Ensure that the last two dims are row contiguous.
  // TODO: Check if we really need this for x as well...
  std::vector<array> copies;
@@ -266,256 +64,205 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s_strides = scales.strides();
  auto& b_strides = biases.strides();

  int D = x.shape(-1);
  int B = x.shape(-2);
  int O = out.shape(-1);
  int N = out.size() / B / O;
  if (transpose_) {
    // Route to the fast bs_qmv kernel that has no bounds checking
    if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
      std::ostringstream kname;
      auto type_string = get_type_string(x.dtype());
      kname << "bs_qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
            << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "bs_qmv_fast", type_string, group_size_, bits_);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      int bo = 8;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, 2, 1);
      MTL::Size grid_dims = MTL::Size(O / bo, B, N);

      compute_encoder.set_input_array(w, 0);
      compute_encoder.set_input_array(scales, 1);
      compute_encoder.set_input_array(biases, 2);
      compute_encoder.set_input_array(x, 3);
      compute_encoder.set_input_array(lhs_indices, 4);
      compute_encoder.set_input_array(rhs_indices, 5);
      compute_encoder.set_output_array(out, 6);
      compute_encoder->setBytes(&D, sizeof(int), 7);
      compute_encoder->setBytes(&O, sizeof(int), 8);

      compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
      set_vector_bytes(compute_encoder, batch_shape, 10);
      set_vector_bytes(compute_encoder, lhs_strides, 11);
      set_vector_bytes(compute_encoder, rhs_strides, 12);

      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
      set_vector_bytes(compute_encoder, x_shape, 14);
      set_vector_bytes(compute_encoder, x_strides, 15);
      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
      set_vector_bytes(compute_encoder, w_shape, 17);
      set_vector_bytes(compute_encoder, w_strides, 18);
      set_vector_bytes(compute_encoder, s_strides, 19);
      set_vector_bytes(compute_encoder, b_strides, 20);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }

    else if (B < 6) {
      std::ostringstream kname;
      auto type_string = get_type_string(x.dtype());
      kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
            << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto template_def = get_template_definition(
          kname.str(), "bs_qmv", type_string, group_size_, bits_);
      auto kernel = get_quantized_kernel(d, kname.str(), template_def);
      compute_encoder->setComputePipelineState(kernel);

      int bo = 8;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, 2, 1);
      MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);

      compute_encoder.set_input_array(w, 0);
      compute_encoder.set_input_array(scales, 1);
      compute_encoder.set_input_array(biases, 2);
      compute_encoder.set_input_array(x, 3);
      compute_encoder.set_input_array(lhs_indices, 4);
      compute_encoder.set_input_array(rhs_indices, 5);
      compute_encoder.set_output_array(out, 6);
      compute_encoder->setBytes(&D, sizeof(int), 7);
      compute_encoder->setBytes(&O, sizeof(int), 8);

      compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
      set_vector_bytes(compute_encoder, batch_shape, 10);
      set_vector_bytes(compute_encoder, lhs_strides, 11);
      set_vector_bytes(compute_encoder, rhs_strides, 12);

      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
      set_vector_bytes(compute_encoder, x_shape, 14);
      set_vector_bytes(compute_encoder, x_strides, 15);
      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
      set_vector_bytes(compute_encoder, w_shape, 17);
      set_vector_bytes(compute_encoder, w_strides, 18);
      set_vector_bytes(compute_encoder, s_strides, 19);
      set_vector_bytes(compute_encoder, b_strides, 20);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }

    // Route to the bs_qmm_t
    else {
      std::ostringstream kname;
      std::string aligned_n = (O % 32) == 0 ? "true" : "false";
      auto type_string = get_type_string(out.dtype());
      kname << "bs_qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
            << bits_ << "_alN_" << aligned_n;

std::ostringstream kname;
|
||||
auto type_string = get_type_string(x.dtype());
|
||||
kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits;
|
||||
if (quad) {
|
||||
kname << "_d_" << D;
|
||||
}
|
||||
if (aligned) {
|
||||
kname << "_alN_" << aligned_n;
|
||||
}
|
||||
if (!gather) {
|
||||
kname << "_batch_" << batched;
|
||||
}

// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "bs_qmm_t", type_string, group_size_, bits_, aligned_n);
std::string template_def;
if (quad) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, D, batched);
} else if (aligned && !gather) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, aligned_n, batched);
} else if (!gather && !aligned) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, batched);
} else if (aligned && gather) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, aligned_n);
} else {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits);
}
auto& d = metal::device(s.device);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);

compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);

int offset = 7;
if (matrix) {
compute_encoder->setBytes(&B, sizeof(int), 7);
offset += 1;
}

if (batched || gather) {
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), offset);
set_vector_bytes(compute_encoder, x_shape, offset + 1);
set_vector_bytes(compute_encoder, x_strides, offset + 2);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), offset + 3);
set_vector_bytes(compute_encoder, w_shape, offset + 4);
set_vector_bytes(compute_encoder, w_strides, offset + 5);
set_vector_bytes(compute_encoder, s_strides, offset + 6);
set_vector_bytes(compute_encoder, b_strides, offset + 7);
}
if (gather) {
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];

// TODO: collapse batch dims
auto& batch_shape = lhs_indices.shape();
int batch_ndims = batch_shape.size();
auto& lhs_strides = lhs_indices.strides();
auto& rhs_strides = rhs_indices.strides();

compute_encoder->setBytes(&batch_ndims, sizeof(int), offset + 8);
set_vector_bytes(compute_encoder, batch_shape, offset + 9);
compute_encoder.set_input_array(lhs_indices, offset + 10);
compute_encoder.set_input_array(rhs_indices, offset + 11);
set_vector_bytes(compute_encoder, lhs_strides, offset + 12);
set_vector_bytes(compute_encoder, rhs_strides, offset + 13);
}

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
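// Binding sketch (illustrative, not part of the change): for a batched,
// non-gather matrix kernel the buffers above resolve to 0:w 1:scales
// 2:biases 3:x 4:out 5:D 6:O 7:B, then 8:x_batch_ndims 9:x_shape
// 10:x_strides 11:w_batch_ndims 12:w_shape 13:w_strides 14:s_strides
// 15:b_strides; the gather indices follow from offset + 8 when gather is set.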

void qmm_op(
const std::vector<array>& inputs,
array& out,
bool transpose,
int group_size,
int bits,
bool gather,
const Stream& s) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));

MTL::Size group_dims;
MTL::Size grid_dims;

auto& x = inputs[0];
auto& w = inputs[1];
bool batched = !gather && (w.ndim() > 2 || !x.flags().row_contiguous);

int D = x.shape(-1);
int O = out.shape(-1);
// For the unbatched W case, avoid `adjust_matrix_offsets`
// for a small performance gain.
int B = (batched || gather) ? x.shape(-2) : x.size() / D;
int N = (batched || gather) ? out.size() / B / O : 1;
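// Worked example (illustrative): for x of shape (2, 3, 10, 512) with a 3-D
// batched w, D = 512, B = x.shape(-2) = 10 and N = out.size() / B / O = 6
// counts the leading batch. With a 2-D w and row-contiguous x the batch
// folds into the rows instead: B = x.size() / D = 60 and N = 1, which is
// what lets the kernel skip `adjust_matrix_offsets`.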

std::string name = gather ? "bs_" : "";
bool matrix = false;
bool aligned = false;
bool quad = false;

if (transpose) {
if (B < 6 && (D == 128 || D == 64)) {
name += "qmv_quad";
constexpr int quads_per_simd = 8;
constexpr int results_per_quadgroup = 8;
int bo = quads_per_simd * results_per_quadgroup;
int simdgroup_size = 32;
group_dims = MTL::Size(simdgroup_size, 1, 1);
grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
quad = true;
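// Sizing note (illustrative): a 32-thread simdgroup holds 8 quads and each
// quadgroup produces 8 outputs, so bo = 64 output rows per threadgroup and
// the grid covers ceil(O / 64) threadgroups along x.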
} else if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
name += "qmv_fast";
int bo = 8;
int bd = 32;
group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size(O / bo, B, N);
} else if (B < 6) {
name += "qmv";
int bo = 8;
int bd = 32;
group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
} else {
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 32;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N);

compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&B, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&D, sizeof(int), 9);

compute_encoder->setBytes(&batch_ndims, sizeof(int), 10);
set_vector_bytes(compute_encoder, batch_shape, 11);
set_vector_bytes(compute_encoder, lhs_strides, 12);
set_vector_bytes(compute_encoder, rhs_strides, 13);

compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14);
set_vector_bytes(compute_encoder, x_shape, 15);
set_vector_bytes(compute_encoder, x_strides, 16);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17);
set_vector_bytes(compute_encoder, w_shape, 18);
set_vector_bytes(compute_encoder, w_strides, 19);
set_vector_bytes(compute_encoder, s_strides, 20);
set_vector_bytes(compute_encoder, b_strides, 21);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
group_dims = MTL::Size(32, wn, wm);
grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N);
name += "qmm_t";
matrix = true;
aligned = true;
}
} else {
// Route to the bs_qvm kernel
if (B < 4) {
std::ostringstream kname;
auto type_string = get_type_string(out.dtype());
kname << "bs_qvm_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;

// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "bs_qvm", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);

name += "qvm";
int bo = 64;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, N);

compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);

compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, lhs_strides, 11);
set_vector_bytes(compute_encoder, rhs_strides, 12);

compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
set_vector_bytes(compute_encoder, x_shape, 14);
set_vector_bytes(compute_encoder, x_strides, 15);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
set_vector_bytes(compute_encoder, w_shape, 17);
set_vector_bytes(compute_encoder, w_strides, 18);
set_vector_bytes(compute_encoder, s_strides, 19);
set_vector_bytes(compute_encoder, b_strides, 20);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}

// Route to bs_qmm_n
else {
std::ostringstream kname;
auto type_string = get_type_string(out.dtype());
kname << "bs_qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;

// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "bs_qmm_n", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);

group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size(O / bo, B, N);
} else {
name += "qmm_n";
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 32;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N);

group_dims = MTL::Size(32, wn, wm);
grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N);
matrix = true;
if ((O % bn) != 0) {
std::ostringstream msg;
msg << "[quantized_matmul] The output size should be divisible by "
<< bn << " but received " << O << ".";
throw std::runtime_error(msg.str());
}

compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&B, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&D, sizeof(int), 9);

compute_encoder->setBytes(&batch_ndims, sizeof(int), 10);
set_vector_bytes(compute_encoder, batch_shape, 11);
set_vector_bytes(compute_encoder, lhs_strides, 12);
set_vector_bytes(compute_encoder, rhs_strides, 13);

compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14);
set_vector_bytes(compute_encoder, x_shape, 15);
set_vector_bytes(compute_encoder, x_strides, 16);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17);
set_vector_bytes(compute_encoder, w_shape, 18);
set_vector_bytes(compute_encoder, w_strides, 19);
set_vector_bytes(compute_encoder, s_strides, 20);
set_vector_bytes(compute_encoder, b_strides, 21);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}
launch_qmm(
name,
inputs,
out,
group_size,
bits,
D,
O,
B,
N,
group_dims,
grid_dims,
batched,
matrix,
gather,
aligned,
quad,
s);
}

void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
qmm_op(
inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream());
}
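// Input layout note (illustrative): QuantizedMatmul carries {x, w, scales,
// biases}, hence the size-4 assert; GatherQMM below adds lhs_indices and
// rhs_indices, which launch_qmm reads back as inputs[4] and inputs[5].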

void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
qmm_op(
inputs, out, transpose_, group_size_, bits_, /*gather=*/true, stream());
}

void fast::AffineQuantize::eval_gpu(

@ -725,15 +725,6 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {

int el_per_int = 32 / bits;

if (w.shape(-1) < 32 * el_per_int) {
std::ostringstream msg;
msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
<< "too small for quantization. We support >=512 for 2 bits, "
<< ">= 256 for 4 bits and >= 128 for 8 bits. The provided matrix has "
<< "shape " << w.shape() << ".";
throw std::invalid_argument(msg.str());
}
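// Arithmetic note (illustrative): the cutoff is 32 * el_per_int =
// 32 * (32 / bits) = 1024 / bits, i.e. 512 columns at 2 bits, 256 at 4 bits
// and 128 at 8 bits, matching the message above.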

auto fallback = [group_size, bits, el_per_int, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];

mlx/ops.cpp
@ -3592,10 +3592,10 @@ array conv_general(
}

array quantized_matmul(
const array& x,
const array& w,
const array& scales,
const array& biases,
array x,
array w,
array scales,
array biases,
bool transpose /* = true */,
int group_size /* = 64 */,
int bits /* = 4 */,
@ -3604,11 +3604,27 @@ array quantized_matmul(
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"quantized_matmul", x, w, scales, biases, transpose, group_size, bits);

if (w.ndim() != 2) {
std::ostringstream msg;
msg << "[quantized_matmul] Batched quantized matmul is not supported for now "
<< "received w with shape " << w.shape();
throw std::invalid_argument(msg.str());
// QuantizedMatmul handles the w.ndim() == 2 case.
if (x.ndim() > 2 && w.ndim() > 2) {
std::vector<int> bsx_x(x.shape().begin(), x.shape().end() - 2);
std::vector<int> bsx_w(w.shape().begin(), w.shape().end() - 2);
auto inner_shape = broadcast_shapes(bsx_x, bsx_w);

// Broadcast x
inner_shape.push_back(x.shape(-2));
inner_shape.push_back(x.shape(-1));
x = broadcast_to(x, inner_shape, s);

// Broadcast w
*(inner_shape.end() - 2) = w.shape(-2);
*(inner_shape.end() - 1) = w.shape(-1);
w = broadcast_to(w, inner_shape, s);

*(inner_shape.end() - 1) = scales.shape(-1);
scales = broadcast_to(scales, inner_shape, s);

*(inner_shape.end() - 1) = biases.shape(-1);
biases = broadcast_to(biases, inner_shape, s);
}
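// Broadcast sketch (illustrative): an x with leading batch (2, 1) and a w
// with leading batch (1, 3) share the broadcast batch (2, 3); x, w, scales
// and biases are each expanded to that batch with their own trailing two
// axes restored before dispatch.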

auto dtype = result_type(x, scales, biases);

@ -1287,10 +1287,10 @@ array conv_transpose3d(

/** Quantized matmul multiplies x with a quantized matrix w*/
array quantized_matmul(
const array& x,
const array& w,
const array& scales,
const array& biases,
array x,
array w,
array scales,
array biases,
bool transpose = true,
int group_size = 64,
int bits = 4,

@ -117,19 +117,24 @@ class TestQuantized(mlx_tests.MLXTestCase):
tests = product(
[128, 64, 32],  # group_size
[2, 4, 8],  # bits
[512, 1024],  # M
[512, 1024],  # N
[512, 1024, 67],  # M
[64, 128, 512, 1024],  # N
[0, 1, 3, 8],  # B
)
for group_size, bits, M, N in tests:
with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
x = mx.random.normal(shape=(1, N), key=k1)
w = mx.random.normal(shape=(M, N), key=k2)
for group_size, bits, M, N, B in tests:
if group_size > N:
continue
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
x_shape = (3, 1, N) if B == 0 else (B, 1, N)
w_shape = (M, N) if B == 0 else (B, M, N)
x = mx.random.normal(shape=x_shape, key=k1)
w = mx.random.normal(shape=w_shape, key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
y_q = mx.quantized_matmul(
x, w_q, scales, biases, True, group_size, bits
)
y_hat = x @ w_hat.T
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
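# Note (illustrative): B == 0 exercises the unbatched-w path, pairing a
# 2-D w with a 3-D x so the batch-folding route is covered alongside the
# fully batched cases.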

@ -140,12 +145,15 @@ class TestQuantized(mlx_tests.MLXTestCase):
[128, 64, 32],  # group_size
[2, 4, 8],  # bits
[512, 1024],  # M
[512, 1024],  # N
[512, 1024, 67],  # N
[0, 1, 3, 8],  # B
)
for group_size, bits, M, N in tests:
with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
x = mx.random.normal(shape=(1, N), key=k1)
w = mx.random.normal(shape=(N, M), key=k2)
for group_size, bits, M, N, B in tests:
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
x_shape = (1, N) if B == 0 else (B, 1, N)
w_shape = (N, M) if B == 0 else (B, N, M)
x = mx.random.normal(shape=x_shape, key=k1)
w = mx.random.normal(shape=w_shape, key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
y_q = mx.quantized_matmul(

@ -172,33 +180,35 @@ class TestQuantized(mlx_tests.MLXTestCase):
mx.eval(y)

def test_small_matrix(self):
w = mx.random.normal(shape=(8, 256))
for w_shape in [(8, 256), (1, 8, 256), (3, 8, 256)]:
with self.subTest(w_shape=w_shape):
w = mx.random.normal(shape=(w_shape))
w_q, scales, biases = mx.quantize(w)
w_hat = mx.dequantize(w_q, scales, biases)

# Test qmv
x = mx.random.normal(shape=(1, 256))
x = mx.random.normal(shape=(3, 1, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
y_hat = x @ w_hat.T
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)

# Test qmm_t
x = mx.random.normal(shape=(10, 256))
x = mx.random.normal(shape=(3, 10, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
y_hat = x @ w_hat.T
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)

# Test qmv
x = mx.random.normal(shape=(1, 8))
# Test qvm
x = mx.random.normal(shape=(3, 1, 8))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)

# Test qmm
x = mx.random.normal(shape=(10, 8))
x = mx.random.normal(shape=(3, 10, 8))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)