mirror of https://github.com/ml-explore/mlx.git
Block sparse qmm (#1124)

commit e78a6518fa
parent 1873ffda01
committed by GitHub
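Summary of the change (as read from the diff below): the quantized Metal kernels (qmv_fast, qmv, qvm, qmm_t, qmm_n) are refactored into METAL_FUNC *_impl helpers plus thin [[kernel]] wrappers; the output index moves from tid.y to tid.x and the batch index from tid.z to tid.y, with matching grid-dimension changes on the host side; new block-sparse bs_* kernels gather per-batch operands through lhs_indices/rhs_indices via a new adjust_matrix_offsets helper; and a new BlockSparseQMM::eval_gpu routes between bs_qmv_fast, bs_qmv, bs_qmm_t, bs_qvm, and bs_qmm_n based on shape.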
--- a/mlx/backend/metal/kernels/quantized.metal
+++ b/mlx/backend/metal/kernels/quantized.metal
@@ -378,14 +378,14 @@ struct QuantizedBlockLoader {
 };
 
 template <typename T, int group_size, int bits, int packs_per_thread>
-[[kernel]] void qmv_fast(
-    const device uint32_t* w [[buffer(0)]],
-    const device T* scales [[buffer(1)]],
-    const device T* biases [[buffer(2)]],
-    const device T* x [[buffer(3)]],
-    device T* y [[buffer(4)]],
-    const constant int& in_vec_size [[buffer(5)]],
-    const constant int& out_vec_size [[buffer(6)]],
+METAL_FUNC void qmv_fast_impl(
+    const device uint32_t* w,
+    const device T* scales,
+    const device T* biases,
+    const device T* x,
+    device T* y,
+    const constant int& in_vec_size,
+    const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -404,13 +404,13 @@ template <typename T, int group_size, int bits, int packs_per_thread>
   // Adjust positions
   const int in_vec_size_w = in_vec_size / pack_factor;
   const int in_vec_size_g = in_vec_size / group_size;
-  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
+  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
       simd_gid * results_per_simdgroup;
   w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
   scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
   biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
-  x += tid.z * in_vec_size + simd_lid * values_per_thread;
-  y += tid.z * out_vec_size + out_row;
+  x += tid.y * in_vec_size + simd_lid * values_per_thread;
+  y += tid.y * out_vec_size + out_row;
 
   for (int k = 0; k < in_vec_size; k += block_size) {
     U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
@@ -440,15 +440,15 @@ template <typename T, int group_size, int bits, int packs_per_thread>
   }
 }
 
-template <typename T, const int group_size, const int bits>
-[[kernel]] void qmv(
-    const device uint32_t* w [[buffer(0)]],
-    const device T* scales [[buffer(1)]],
-    const device T* biases [[buffer(2)]],
-    const device T* x [[buffer(3)]],
-    device T* y [[buffer(4)]],
-    const constant int& in_vec_size [[buffer(5)]],
-    const constant int& out_vec_size [[buffer(6)]],
+template <typename T, int group_size, int bits>
+METAL_FUNC void qmv_impl(
+    const device uint32_t* w,
+    const device T* scales,
+    const device T* biases,
+    const device T* x,
+    device T* y,
+    const constant int& in_vec_size,
+    const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -468,7 +468,7 @@ template <typename T, const int group_size, const int bits>
   // Adjust positions
   const int in_vec_size_w = in_vec_size / pack_factor;
   const int in_vec_size_g = in_vec_size / group_size;
-  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
+  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
       simd_gid * results_per_simdgroup;
   const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
 
@@ -482,8 +482,8 @@ template <typename T, const int group_size, const int bits>
   w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
   scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
   biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
-  x += tid.z * in_vec_size + simd_lid * values_per_thread;
-  y += tid.z * out_vec_size + out_row;
+  x += tid.y * in_vec_size + simd_lid * values_per_thread;
+  y += tid.y * out_vec_size + out_row;
 
   int k = 0;
   for (; k < in_vec_size - block_size; k += block_size) {
@@ -537,8 +537,8 @@ template <typename T, const int group_size, const int bits>
   w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
   scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
   biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
-  x += tid.z * in_vec_size + simd_lid * values_per_thread;
-  y += tid.z * out_vec_size + used_out_row;
+  x += tid.y * in_vec_size + simd_lid * values_per_thread;
+  y += tid.y * out_vec_size + used_out_row;
 
   int k = 0;
   for (; k < in_vec_size - block_size; k += block_size) {
@@ -590,14 +590,14 @@ template <typename T, const int group_size, const int bits>
 }
 
 template <typename T, const int group_size, const int bits>
-[[kernel]] void qvm(
-    const device T* x [[buffer(0)]],
-    const device uint32_t* w [[buffer(1)]],
-    const device T* scales [[buffer(2)]],
-    const device T* biases [[buffer(3)]],
-    device T* y [[buffer(4)]],
-    const constant int& in_vec_size [[buffer(5)]],
-    const constant int& out_vec_size [[buffer(6)]],
+METAL_FUNC void qvm_impl(
+    const device T* x,
+    const device uint32_t* w,
+    const device T* scales,
+    const device T* biases,
+    device T* y,
+    const constant int& in_vec_size,
+    const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -616,12 +616,12 @@ template <typename T, const int group_size, const int bits>
   // Adjust positions
   const int out_vec_size_w = out_vec_size / pack_factor;
   const int out_vec_size_g = out_vec_size / group_size;
-  int out_col = tid.y * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
+  int out_col = tid.x * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
   w += out_col / pack_factor;
   scales += out_col / group_size;
   biases += out_col / group_size;
-  x += tid.z * in_vec_size;
-  y += tid.z * out_vec_size + out_col;
+  x += tid.y * in_vec_size;
+  y += tid.y * out_vec_size + out_col;
 
   if (out_col >= out_vec_size) {
     return;
@@ -675,15 +675,17 @@ template <
     const int group_size,
     const int bits,
     const bool aligned_N>
-[[kernel]] void qmm_t(
-    const device T* x [[buffer(0)]],
-    const device uint32_t* w [[buffer(1)]],
-    const device T* scales [[buffer(2)]],
-    const device T* biases [[buffer(3)]],
-    device T* y [[buffer(4)]],
-    const constant int& M [[buffer(5)]],
-    const constant int& N [[buffer(6)]],
-    const constant int& K [[buffer(7)]],
+METAL_FUNC void qmm_t_impl(
+    const device T* x,
+    const device uint32_t* w,
+    const device T* scales,
+    const device T* biases,
+    device T* y,
+    threadgroup T* Xs,
+    threadgroup T* Ws,
+    const constant int& M,
+    const constant int& N,
+    const constant int& K,
     uint3 tid [[threadgroup_position_in_grid]],
     uint lid [[thread_index_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -713,9 +715,6 @@ template <
       group_size,
       bits>;
 
-  threadgroup T Xs[BM * BK_padded];
-  threadgroup T Ws[BN * BK_padded];
-
   // Set the block
   const int K_w = K / pack_factor;
   const int K_g = K / group_size;
@@ -797,15 +796,17 @@ template <
     const int BN,
     const int group_size,
     const int bits>
-[[kernel]] void qmm_n(
-    const device T* x [[buffer(0)]],
-    const device uint32_t* w [[buffer(1)]],
-    const device T* scales [[buffer(2)]],
-    const device T* biases [[buffer(3)]],
-    device T* y [[buffer(4)]],
-    const constant int& M [[buffer(5)]],
-    const constant int& N [[buffer(6)]],
-    const constant int& K [[buffer(7)]],
+METAL_FUNC void qmm_n_impl(
+    const device T* x,
+    const device uint32_t* w,
+    const device T* scales,
+    const device T* biases,
+    device T* y,
+    threadgroup T* Xs,
+    threadgroup T* Ws,
+    const constant int& M,
+    const constant int& N,
+    const constant int& K,
     uint3 tid [[threadgroup_position_in_grid]],
     uint lid [[thread_index_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -836,9 +837,6 @@ template <
      group_size,
      bits>;
 
-  threadgroup T Xs[BM * BK_padded];
-  threadgroup T Ws[BK * BN_padded];
-
   // Set the block
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;
@@ -923,6 +921,518 @@ template <
   }
 }
 
+template <typename T>
+METAL_FUNC void adjust_matrix_offsets(
+    const device T*& x,
+    const device uint32_t*& w,
+    const device T*& scales,
+    const device T*& biases,
+    const device uint32_t* lhs_indices,
+    const device uint32_t* rhs_indices,
+    device T*& y,
+    int output_stride,
+    const constant int& batch_ndims,
+    const constant int* batch_shape,
+    const constant size_t* lhs_strides,
+    const constant size_t* rhs_strides,
+    const constant int& x_batch_ndims,
+    const constant int* x_shape,
+    const constant size_t* x_strides,
+    const constant int& w_batch_ndims,
+    const constant int* w_shape,
+    const constant size_t* w_strides,
+    const constant size_t* s_strides,
+    const constant size_t* b_strides,
+    uint3 tid [[threadgroup_position_in_grid]]) {
+  // Set the input/output matrices
+  uint32_t x_idx;
+  uint32_t w_idx;
+  if (batch_ndims == 1) {
+    x_idx = lhs_indices[tid.z * lhs_strides[0]];
+    w_idx = rhs_indices[tid.z * rhs_strides[0]];
+  } else {
+    ulong2 idx = elem_to_loc_broadcast(
+        tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
+    x_idx = lhs_indices[idx.x];
+    w_idx = rhs_indices[idx.y];
+  }
+  if (x_batch_ndims == 1) {
+    x += x_idx * x_strides[0];
+  } else {
+    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
+  }
+  if (w_batch_ndims == 1) {
+    w += w_idx * w_strides[0];
+    scales += w_idx * s_strides[0];
+    biases += w_idx * b_strides[0];
+  } else {
+    ulong3 idx = elem_to_loc_broadcast(
+        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
+    w += idx.x;
+    scales += idx.y;
+    biases += idx.z;
+  }
+  y += tid.z * output_stride;
+}
+
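adjust_matrix_offsets is what makes the bs_* kernels below "block sparse": lhs_indices and rhs_indices select which x matrix and which quantized w (with its scales and biases) each batch element of the grid reads, and elem_to_loc converts those flat indices into strided memory offsets. A minimal host-side sketch of that index arithmetic, assuming the usual last-axis-fastest shape/strides convention (the helper below is illustrative C++, not the kernel's code):

#include <cstddef>
#include <cstdio>

// Illustrative analogue of elem_to_loc: map a flat element index onto a
// memory offset for an array with the given shape and strides.
size_t elem_to_loc(int elem, const int* shape, const size_t* strides, int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    loc += static_cast<size_t>(elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A [3, 4] array whose rows are padded to 8 elements: strides {8, 1}.
  // Flat index 7 is element (1, 3), i.e. memory offset 1 * 8 + 3 = 11.
  int shape[] = {3, 4};
  size_t strides[] = {8, 1};
  printf("%zu\n", elem_to_loc(7, shape, strides, 2)); // prints 11
}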
+template <typename T, int group_size, int bits, int packs_per_thread>
+[[kernel]] void qmv_fast(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& in_vec_size [[buffer(5)]],
+    const constant int& out_vec_size [[buffer(6)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  qmv_fast_impl<T, group_size, bits, packs_per_thread>(
+      w,
+      scales,
+      biases,
+      x,
+      y,
+      in_vec_size,
+      out_vec_size,
+      tid,
+      simd_gid,
+      simd_lid);
+}
+
+template <typename T, const int group_size, const int bits>
+[[kernel]] void qmv(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& in_vec_size [[buffer(5)]],
+    const constant int& out_vec_size [[buffer(6)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  qmv_impl<T, group_size, bits>(
+      w,
+      scales,
+      biases,
+      x,
+      y,
+      in_vec_size,
+      out_vec_size,
+      tid,
+      simd_gid,
+      simd_lid);
+}
+
+template <typename T, const int group_size, const int bits>
+[[kernel]] void qvm(
+    const device T* x [[buffer(0)]],
+    const device uint32_t* w [[buffer(1)]],
+    const device T* scales [[buffer(2)]],
+    const device T* biases [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& in_vec_size [[buffer(5)]],
+    const constant int& out_vec_size [[buffer(6)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  qvm_impl<T, group_size, bits>(
+      x,
+      w,
+      scales,
+      biases,
+      y,
+      in_vec_size,
+      out_vec_size,
+      tid,
+      simd_gid,
+      simd_lid);
+}
+
+template <
+    typename T,
+    const int BM,
+    const int BK,
+    const int BN,
+    const int group_size,
+    const int bits,
+    const bool aligned_N>
+[[kernel]] void qmm_t(
+    const device T* x [[buffer(0)]],
+    const device uint32_t* w [[buffer(1)]],
+    const device T* scales [[buffer(2)]],
+    const device T* biases [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& M [[buffer(5)]],
+    const constant int& N [[buffer(6)]],
+    const constant int& K [[buffer(7)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  (void)lid;
+
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[BN * BK_padded];
+
+  qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
+      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
+}
+
+template <
+    typename T,
+    const int BM,
+    const int BK,
+    const int BN,
+    const int group_size,
+    const int bits>
+[[kernel]] void qmm_n(
+    const device T* x [[buffer(0)]],
+    const device uint32_t* w [[buffer(1)]],
+    const device T* scales [[buffer(2)]],
+    const device T* biases [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& M [[buffer(5)]],
+    const constant int& N [[buffer(6)]],
+    const constant int& K [[buffer(7)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  (void)lid;
+
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+  constexpr int BN_padded = (BN + 16 / sizeof(T));
+
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[BK * BN_padded];
+
+  qmm_n_impl<T, BM, BK, BN, group_size, bits>(
+      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
+}
+
+template <typename T, int group_size, int bits, int packs_per_thread>
+[[kernel]] void bs_qmv_fast(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& in_vec_size [[buffer(7)]],
+    const constant int& out_vec_size [[buffer(8)]],
+    const constant int& batch_ndims [[buffer(9)]],
+    const constant int* batch_shape [[buffer(10)]],
+    const constant size_t* lhs_strides [[buffer(11)]],
+    const constant size_t* rhs_strides [[buffer(12)]],
+    const constant int& x_batch_ndims [[buffer(13)]],
+    const constant int* x_shape [[buffer(14)]],
+    const constant size_t* x_strides [[buffer(15)]],
+    const constant int& w_batch_ndims [[buffer(16)]],
+    const constant int* w_shape [[buffer(17)]],
+    const constant size_t* w_strides [[buffer(18)]],
+    const constant size_t* s_strides [[buffer(19)]],
+    const constant size_t* b_strides [[buffer(20)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  adjust_matrix_offsets<T>(
+      x,
+      w,
+      scales,
+      biases,
+      lhs_indices,
+      rhs_indices,
+      y,
+      out_vec_size,
+      batch_ndims,
+      batch_shape,
+      lhs_strides,
+      rhs_strides,
+      x_batch_ndims,
+      x_shape,
+      x_strides,
+      w_batch_ndims,
+      w_shape,
+      w_strides,
+      s_strides,
+      b_strides,
+      tid);
+  qmv_fast_impl<T, group_size, bits, packs_per_thread>(
+      w,
+      scales,
+      biases,
+      x,
+      y,
+      in_vec_size,
+      out_vec_size,
+      tid,
+      simd_gid,
+      simd_lid);
+}
+
+template <typename T, int group_size, int bits>
+[[kernel]] void bs_qmv(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& in_vec_size [[buffer(7)]],
+    const constant int& out_vec_size [[buffer(8)]],
+    const constant int& batch_ndims [[buffer(9)]],
+    const constant int* batch_shape [[buffer(10)]],
+    const constant size_t* lhs_strides [[buffer(11)]],
+    const constant size_t* rhs_strides [[buffer(12)]],
+    const constant int& x_batch_ndims [[buffer(13)]],
+    const constant int* x_shape [[buffer(14)]],
+    const constant size_t* x_strides [[buffer(15)]],
+    const constant int& w_batch_ndims [[buffer(16)]],
+    const constant int* w_shape [[buffer(17)]],
+    const constant size_t* w_strides [[buffer(18)]],
+    const constant size_t* s_strides [[buffer(19)]],
+    const constant size_t* b_strides [[buffer(20)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  adjust_matrix_offsets<T>(
+      x,
+      w,
+      scales,
+      biases,
+      lhs_indices,
+      rhs_indices,
+      y,
+      out_vec_size,
+      batch_ndims,
+      batch_shape,
+      lhs_strides,
+      rhs_strides,
+      x_batch_ndims,
+      x_shape,
+      x_strides,
+      w_batch_ndims,
+      w_shape,
+      w_strides,
+      s_strides,
+      b_strides,
+      tid);
+  qmv_impl<T, group_size, bits>(
+      w,
+      scales,
+      biases,
+      x,
+      y,
+      in_vec_size,
+      out_vec_size,
+      tid,
+      simd_gid,
+      simd_lid);
+}
+
+template <typename T, int group_size, int bits>
+[[kernel]] void bs_qvm(
+    const device T* x [[buffer(0)]],
+    const device uint32_t* w [[buffer(1)]],
+    const device T* scales [[buffer(2)]],
+    const device T* biases [[buffer(3)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& in_vec_size [[buffer(7)]],
+    const constant int& out_vec_size [[buffer(8)]],
+    const constant int& batch_ndims [[buffer(9)]],
+    const constant int* batch_shape [[buffer(10)]],
+    const constant size_t* lhs_strides [[buffer(11)]],
+    const constant size_t* rhs_strides [[buffer(12)]],
+    const constant int& x_batch_ndims [[buffer(13)]],
+    const constant int* x_shape [[buffer(14)]],
+    const constant size_t* x_strides [[buffer(15)]],
+    const constant int& w_batch_ndims [[buffer(16)]],
+    const constant int* w_shape [[buffer(17)]],
+    const constant size_t* w_strides [[buffer(18)]],
+    const constant size_t* s_strides [[buffer(19)]],
+    const constant size_t* b_strides [[buffer(20)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  adjust_matrix_offsets<T>(
+      x,
+      w,
+      scales,
+      biases,
+      lhs_indices,
+      rhs_indices,
+      y,
+      out_vec_size,
+      batch_ndims,
+      batch_shape,
+      lhs_strides,
+      rhs_strides,
+      x_batch_ndims,
+      x_shape,
+      x_strides,
+      w_batch_ndims,
+      w_shape,
+      w_strides,
+      s_strides,
+      b_strides,
+      tid);
+  qvm_impl<T, group_size, bits>(
+      x,
+      w,
+      scales,
+      biases,
+      y,
+      in_vec_size,
+      out_vec_size,
+      tid,
+      simd_gid,
+      simd_lid);
+}
+
+template <
+    typename T,
+    const int BM,
+    const int BK,
+    const int BN,
+    const int group_size,
+    const int bits,
+    const bool aligned_N>
+[[kernel]] void bs_qmm_t(
+    const device T* x [[buffer(0)]],
+    const device uint32_t* w [[buffer(1)]],
+    const device T* scales [[buffer(2)]],
+    const device T* biases [[buffer(3)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& M [[buffer(7)]],
+    const constant int& N [[buffer(8)]],
+    const constant int& K [[buffer(9)]],
+    const constant int& batch_ndims [[buffer(10)]],
+    const constant int* batch_shape [[buffer(11)]],
+    const constant size_t* lhs_strides [[buffer(12)]],
+    const constant size_t* rhs_strides [[buffer(13)]],
+    const constant int& x_batch_ndims [[buffer(14)]],
+    const constant int* x_shape [[buffer(15)]],
+    const constant size_t* x_strides [[buffer(16)]],
+    const constant int& w_batch_ndims [[buffer(17)]],
+    const constant int* w_shape [[buffer(18)]],
+    const constant size_t* w_strides [[buffer(19)]],
+    const constant size_t* s_strides [[buffer(20)]],
+    const constant size_t* b_strides [[buffer(21)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  (void)lid;
+
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[BN * BK_padded];
+
+  adjust_matrix_offsets<T>(
+      x,
+      w,
+      scales,
+      biases,
+      lhs_indices,
+      rhs_indices,
+      y,
+      M * N,
+      batch_ndims,
+      batch_shape,
+      lhs_strides,
+      rhs_strides,
+      x_batch_ndims,
+      x_shape,
+      x_strides,
+      w_batch_ndims,
+      w_shape,
+      w_strides,
+      s_strides,
+      b_strides,
+      tid);
+  qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
+      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
+}
+
+template <
+    typename T,
+    const int BM,
+    const int BK,
+    const int BN,
+    const int group_size,
+    const int bits>
+[[kernel]] void bs_qmm_n(
+    const device T* x [[buffer(0)]],
+    const device uint32_t* w [[buffer(1)]],
+    const device T* scales [[buffer(2)]],
+    const device T* biases [[buffer(3)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& M [[buffer(7)]],
+    const constant int& N [[buffer(8)]],
+    const constant int& K [[buffer(9)]],
+    const constant int& batch_ndims [[buffer(10)]],
+    const constant int* batch_shape [[buffer(11)]],
+    const constant size_t* lhs_strides [[buffer(12)]],
+    const constant size_t* rhs_strides [[buffer(13)]],
+    const constant int& x_batch_ndims [[buffer(14)]],
+    const constant int* x_shape [[buffer(15)]],
+    const constant size_t* x_strides [[buffer(16)]],
+    const constant int& w_batch_ndims [[buffer(17)]],
+    const constant int* w_shape [[buffer(18)]],
+    const constant size_t* w_strides [[buffer(19)]],
+    const constant size_t* s_strides [[buffer(20)]],
+    const constant size_t* b_strides [[buffer(21)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  (void)lid;
+
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+  constexpr int BN_padded = (BN + 16 / sizeof(T));
+
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[BK * BN_padded];
+
+  adjust_matrix_offsets<T>(
+      x,
+      w,
+      scales,
+      biases,
+      lhs_indices,
+      rhs_indices,
+      y,
+      M * N,
+      batch_ndims,
+      batch_shape,
+      lhs_strides,
+      rhs_strides,
+      x_batch_ndims,
+      x_shape,
+      x_strides,
+      w_batch_ndims,
+      w_shape,
+      w_strides,
+      s_strides,
+      b_strides,
+      tid);
+  qmm_n_impl<T, BM, BK, BN, group_size, bits>(
+      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
+}
+
 #define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \
   template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits \
                        "_fast")]] [[kernel]] void \
@@ -1089,3 +1599,241 @@ instantiate_qmm_n_types( 64, 8)
 instantiate_qmm_n_types( 32, 2)
 instantiate_qmm_n_types( 32, 4)
 instantiate_qmm_n_types( 32, 8) // clang-format on
+
+#define instantiate_bs_qmv_fast( \
+    name, itype, group_size, bits, packs_per_thread) \
+  template [[host_name("bs_qmv_" #name "_gs_" #group_size "_b_" #bits \
+                       "_fast")]] [[kernel]] void \
+  bs_qmv_fast<itype, group_size, bits, packs_per_thread>( \
+      const device uint32_t* w [[buffer(0)]], \
+      const device itype* scales [[buffer(1)]], \
+      const device itype* biases [[buffer(2)]], \
+      const device itype* x [[buffer(3)]], \
+      const device uint32_t* lhs_indices [[buffer(4)]], \
+      const device uint32_t* rhs_indices [[buffer(5)]], \
+      device itype* y [[buffer(6)]], \
+      const constant int& in_vec_size [[buffer(7)]], \
+      const constant int& out_vec_size [[buffer(8)]], \
+      const constant int& batch_ndims [[buffer(9)]], \
+      const constant int* batch_shape [[buffer(10)]], \
+      const constant size_t* lhs_strides [[buffer(11)]], \
+      const constant size_t* rhs_strides [[buffer(12)]], \
+      const constant int& x_batch_ndims [[buffer(13)]], \
+      const constant int* x_shape [[buffer(14)]], \
+      const constant size_t* x_strides [[buffer(15)]], \
+      const constant int& w_batch_ndims [[buffer(16)]], \
+      const constant int* w_shape [[buffer(17)]], \
+      const constant size_t* w_strides [[buffer(18)]], \
+      const constant size_t* s_strides [[buffer(19)]], \
+      const constant size_t* b_strides [[buffer(20)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+// clang-format off
+#define instantiate_bs_qmv_fast_types(group_size, bits, packs_per_thread) \
+  instantiate_bs_qmv_fast(float32, float, group_size, bits, packs_per_thread) \
+  instantiate_bs_qmv_fast(float16, half, group_size, bits, packs_per_thread) \
+  instantiate_bs_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread) // clang-format on
+
+// clang-format off
+instantiate_bs_qmv_fast_types(128, 2, 1)
+instantiate_bs_qmv_fast_types(128, 4, 2)
+instantiate_bs_qmv_fast_types(128, 8, 2)
+instantiate_bs_qmv_fast_types( 64, 2, 1)
+instantiate_bs_qmv_fast_types( 64, 4, 2)
+instantiate_bs_qmv_fast_types( 64, 8, 2)
+instantiate_bs_qmv_fast_types( 32, 2, 1)
+instantiate_bs_qmv_fast_types( 32, 4, 2)
+instantiate_bs_qmv_fast_types( 32, 8, 2) // clang-format on
+
+#define instantiate_bs_qmv(name, itype, group_size, bits) \
+  template [[host_name("bs_qmv_" #name "_gs_" #group_size \
+                       "_b_" #bits)]] [[kernel]] void \
+  bs_qmv<itype, group_size, bits>( \
+      const device uint32_t* w [[buffer(0)]], \
+      const device itype* scales [[buffer(1)]], \
+      const device itype* biases [[buffer(2)]], \
+      const device itype* x [[buffer(3)]], \
+      const device uint32_t* lhs_indices [[buffer(4)]], \
+      const device uint32_t* rhs_indices [[buffer(5)]], \
+      device itype* y [[buffer(6)]], \
+      const constant int& in_vec_size [[buffer(7)]], \
+      const constant int& out_vec_size [[buffer(8)]], \
+      const constant int& batch_ndims [[buffer(9)]], \
+      const constant int* batch_shape [[buffer(10)]], \
+      const constant size_t* lhs_strides [[buffer(11)]], \
+      const constant size_t* rhs_strides [[buffer(12)]], \
+      const constant int& x_batch_ndims [[buffer(13)]], \
+      const constant int* x_shape [[buffer(14)]], \
+      const constant size_t* x_strides [[buffer(15)]], \
+      const constant int& w_batch_ndims [[buffer(16)]], \
+      const constant int* w_shape [[buffer(17)]], \
+      const constant size_t* w_strides [[buffer(18)]], \
+      const constant size_t* s_strides [[buffer(19)]], \
+      const constant size_t* b_strides [[buffer(20)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+// clang-format off
+#define instantiate_bs_qmv_types(group_size, bits) \
+  instantiate_bs_qmv(float32, float, group_size, bits) \
+  instantiate_bs_qmv(float16, half, group_size, bits) \
+  instantiate_bs_qmv(bfloat16, bfloat16_t, group_size, bits) // clang-format on
+
+// clang-format off
+instantiate_bs_qmv_types(128, 2)
+instantiate_bs_qmv_types(128, 4)
+instantiate_bs_qmv_types(128, 8)
+instantiate_bs_qmv_types( 64, 2)
+instantiate_bs_qmv_types( 64, 4)
+instantiate_bs_qmv_types( 64, 8)
+instantiate_bs_qmv_types( 32, 2)
+instantiate_bs_qmv_types( 32, 4)
+instantiate_bs_qmv_types( 32, 8) // clang-format on
+
+#define instantiate_bs_qvm(name, itype, group_size, bits) \
+  template [[host_name("bs_qvm_" #name "_gs_" #group_size \
+                       "_b_" #bits)]] [[kernel]] void \
+  bs_qvm<itype, group_size, bits>( \
+      const device itype* x [[buffer(0)]], \
+      const device uint32_t* w [[buffer(1)]], \
+      const device itype* scales [[buffer(2)]], \
+      const device itype* biases [[buffer(3)]], \
+      const device uint32_t* lhs_indices [[buffer(4)]], \
+      const device uint32_t* rhs_indices [[buffer(5)]], \
+      device itype* y [[buffer(6)]], \
+      const constant int& in_vec_size [[buffer(7)]], \
+      const constant int& out_vec_size [[buffer(8)]], \
+      const constant int& batch_ndims [[buffer(9)]], \
+      const constant int* batch_shape [[buffer(10)]], \
+      const constant size_t* lhs_strides [[buffer(11)]], \
+      const constant size_t* rhs_strides [[buffer(12)]], \
+      const constant int& x_batch_ndims [[buffer(13)]], \
+      const constant int* x_shape [[buffer(14)]], \
+      const constant size_t* x_strides [[buffer(15)]], \
+      const constant int& w_batch_ndims [[buffer(16)]], \
+      const constant int* w_shape [[buffer(17)]], \
+      const constant size_t* w_strides [[buffer(18)]], \
+      const constant size_t* s_strides [[buffer(19)]], \
+      const constant size_t* b_strides [[buffer(20)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+// clang-format off
+#define instantiate_bs_qvm_types(group_size, bits) \
+  instantiate_bs_qvm(float32, float, group_size, bits) \
+  instantiate_bs_qvm(float16, half, group_size, bits) \
+  instantiate_bs_qvm(bfloat16, bfloat16_t, group_size, bits) // clang-format on
+
+// clang-format off
+instantiate_bs_qvm_types(128, 2)
+instantiate_bs_qvm_types(128, 4)
+instantiate_bs_qvm_types(128, 8)
+instantiate_bs_qvm_types( 64, 2)
+instantiate_bs_qvm_types( 64, 4)
+instantiate_bs_qvm_types( 64, 8)
+instantiate_bs_qvm_types( 32, 2)
+instantiate_bs_qvm_types( 32, 4)
+instantiate_bs_qvm_types( 32, 8) // clang-format on
+
+#define instantiate_bs_qmm_t(name, itype, group_size, bits, aligned_N) \
+  template [[host_name("bs_qmm_t_" #name "_gs_" #group_size "_b_" #bits \
+                       "_alN_" #aligned_N)]] [[kernel]] void \
+  bs_qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
+      const device itype* x [[buffer(0)]], \
+      const device uint32_t* w [[buffer(1)]], \
+      const device itype* scales [[buffer(2)]], \
+      const device itype* biases [[buffer(3)]], \
+      const device uint32_t* lhs_indices [[buffer(4)]], \
+      const device uint32_t* rhs_indices [[buffer(5)]], \
+      device itype* y [[buffer(6)]], \
+      const constant int& M [[buffer(7)]], \
+      const constant int& N [[buffer(8)]], \
+      const constant int& K [[buffer(9)]], \
+      const constant int& batch_ndims [[buffer(10)]], \
+      const constant int* batch_shape [[buffer(11)]], \
+      const constant size_t* lhs_strides [[buffer(12)]], \
+      const constant size_t* rhs_strides [[buffer(13)]], \
+      const constant int& x_batch_ndims [[buffer(14)]], \
+      const constant int* x_shape [[buffer(15)]], \
+      const constant size_t* x_strides [[buffer(16)]], \
+      const constant int& w_batch_ndims [[buffer(17)]], \
+      const constant int* w_shape [[buffer(18)]], \
+      const constant size_t* w_strides [[buffer(19)]], \
+      const constant size_t* s_strides [[buffer(20)]], \
+      const constant size_t* b_strides [[buffer(21)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint lid [[thread_index_in_threadgroup]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+// clang-format off
+#define instantiate_bs_qmm_t_types(group_size, bits) \
+  instantiate_bs_qmm_t(float32, float, group_size, bits, false) \
+  instantiate_bs_qmm_t(float16, half, group_size, bits, false) \
+  instantiate_bs_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \
+  instantiate_bs_qmm_t(float32, float, group_size, bits, true) \
+  instantiate_bs_qmm_t(float16, half, group_size, bits, true) \
+  instantiate_bs_qmm_t(bfloat16, bfloat16_t, group_size, bits, true) // clang-format on
+
+// clang-format off
+instantiate_bs_qmm_t_types(128, 2)
+instantiate_bs_qmm_t_types(128, 4)
+instantiate_bs_qmm_t_types(128, 8)
+instantiate_bs_qmm_t_types( 64, 2)
+instantiate_bs_qmm_t_types( 64, 4)
+instantiate_bs_qmm_t_types( 64, 8)
+instantiate_bs_qmm_t_types( 32, 2)
+instantiate_bs_qmm_t_types( 32, 4)
+instantiate_bs_qmm_t_types( 32, 8) // clang-format on
+
+#define instantiate_bs_qmm_n(name, itype, group_size, bits) \
+  template [[host_name("bs_qmm_n_" #name "_gs_" #group_size \
+                       "_b_" #bits)]] [[kernel]] void \
+  bs_qmm_n<itype, 32, 32, 32, group_size, bits>( \
+      const device itype* x [[buffer(0)]], \
+      const device uint32_t* w [[buffer(1)]], \
+      const device itype* scales [[buffer(2)]], \
+      const device itype* biases [[buffer(3)]], \
+      const device uint32_t* lhs_indices [[buffer(4)]], \
+      const device uint32_t* rhs_indices [[buffer(5)]], \
+      device itype* y [[buffer(6)]], \
+      const constant int& M [[buffer(7)]], \
+      const constant int& N [[buffer(8)]], \
+      const constant int& K [[buffer(9)]], \
+      const constant int& batch_ndims [[buffer(10)]], \
+      const constant int* batch_shape [[buffer(11)]], \
+      const constant size_t* lhs_strides [[buffer(12)]], \
+      const constant size_t* rhs_strides [[buffer(13)]], \
+      const constant int& x_batch_ndims [[buffer(14)]], \
+      const constant int* x_shape [[buffer(15)]], \
+      const constant size_t* x_strides [[buffer(16)]], \
+      const constant int& w_batch_ndims [[buffer(17)]], \
+      const constant int* w_shape [[buffer(18)]], \
+      const constant size_t* w_strides [[buffer(19)]], \
+      const constant size_t* s_strides [[buffer(20)]], \
+      const constant size_t* b_strides [[buffer(21)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint lid [[thread_index_in_threadgroup]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+// clang-format off
+#define instantiate_bs_qmm_n_types(group_size, bits) \
+  instantiate_bs_qmm_n(float32, float, group_size, bits) \
+  instantiate_bs_qmm_n(float16, half, group_size, bits) \
+  instantiate_bs_qmm_n(bfloat16, bfloat16_t, group_size, bits) // clang-format on
+
+// clang-format off
+instantiate_bs_qmm_n_types(128, 2)
+instantiate_bs_qmm_n_types(128, 4)
+instantiate_bs_qmm_n_types(128, 8)
+instantiate_bs_qmm_n_types( 64, 2)
+instantiate_bs_qmm_n_types( 64, 4)
+instantiate_bs_qmm_n_types( 64, 8)
+instantiate_bs_qmm_n_types( 32, 2)
+instantiate_bs_qmm_n_types( 32, 4)
+instantiate_bs_qmm_n_types( 32, 8) // clang-format on
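Each template instantiation above is registered under a host_name assembled from the kernel family, the type name, the group size, and the bit width; the host code in quantized.cpp rebuilds the same string to look up the pipeline. A small sketch of the naming scheme in C++ (the helper is hypothetical; type_name stands in for type_to_name(out) in the real host code):

#include <iostream>
#include <sstream>
#include <string>

// Hypothetical helper mirroring how eval_gpu assembles a bs_qmv_fast name.
std::string bs_qmv_fast_name(const std::string& type_name, int group_size, int bits) {
  std::ostringstream kname;
  kname << "bs_qmv_" << type_name << "_gs_" << group_size << "_b_" << bits
        << "_fast";
  return kname.str();
}

int main() {
  // Matches instantiate_bs_qmv_fast_types( 64, 4, 2) for half precision:
  std::cout << bs_qmv_fast_name("float16", 64, 4) << "\n";
  // -> bs_qmv_float16_gs_64_b_4_fast
}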
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -55,7 +55,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     int bo = 8;
     int bd = 32;
     MTL::Size group_dims = MTL::Size(bd, 2, 1);
-    MTL::Size grid_dims = MTL::Size(1, O / bo, B);
+    MTL::Size grid_dims = MTL::Size(O / bo, B, 1);
 
     compute_encoder.set_input_array(w, 0);
     compute_encoder.set_input_array(scales, 1);
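This grid change pairs with the kernel-side tid.y -> tid.x and tid.z -> tid.y swaps above: blocks of output rows now ride on the grid's x axis and the batch on y. A worked example of the dispatch arithmetic (the values below are illustrative, not from the commit):

#include <cstdio>

// Worked example of the new qmv dispatch shape.
int main() {
  int O = 4096; // output rows
  int B = 2;    // batch size
  int bo = 8;   // output rows produced per threadgroup
  int bd = 32;  // threads along the reduction dimension
  // Before: grid = (1, O / bo, B); after: grid = (O / bo, B, 1).
  printf("group = (%d, 2, 1), grid = (%d, %d, 1)\n", bd, O / bo, B);
  // -> group = (32, 2, 1), grid = (512, 2, 1)
}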
@@ -82,7 +82,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     int bo = 8;
     int bd = 32;
     MTL::Size group_dims = MTL::Size(bd, 2, 1);
-    MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
+    MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);
 
     compute_encoder.set_input_array(w, 0);
     compute_encoder.set_input_array(scales, 1);
@@ -140,7 +140,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     int bo = 8;
     int bd = 32;
     MTL::Size group_dims = MTL::Size(bd, bo, 1);
-    MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
+    MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);
 
     compute_encoder.set_input_array(x, 0);
     compute_encoder.set_input_array(w, 1);
@@ -196,4 +196,289 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
 }
 
+void BlockSparseQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 6);
+
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  auto& x_pre = inputs[0];
+  auto& w_pre = inputs[1];
+  auto& scales_pre = inputs[2];
+  auto& biases_pre = inputs[3];
+  auto& lhs_indices = inputs[4];
+  auto& rhs_indices = inputs[5];
+
+  // TODO: collapse batch dims
+  auto& batch_shape = lhs_indices.shape();
+  int batch_ndims = batch_shape.size();
+  auto& lhs_strides = lhs_indices.strides();
+  auto& rhs_strides = rhs_indices.strides();
+
+  // Ensure that the last two dims are row contiguous.
+  // TODO: Check if we really need this for x as well...
+  std::vector<array> copies;
+  auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) {
+    auto stride_0 = arr.strides()[arr.ndim() - 2];
+    auto stride_1 = arr.strides()[arr.ndim() - 1];
+    if (stride_0 == arr.shape(-1) && stride_1 == 1) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy_gpu(arr, arr_copy, CopyType::General, s);
+      copies.push_back(arr_copy);
+      return arr_copy;
+    }
+  };
+  auto x = ensure_row_contiguous_last_dims(x_pre);
+  auto w = ensure_row_contiguous_last_dims(w_pre);
+  auto scales = ensure_row_contiguous_last_dims(scales_pre);
+  auto biases = ensure_row_contiguous_last_dims(biases_pre);
+
+  int x_batch_ndims = x.ndim() - 2;
+  auto& x_shape = x.shape();
+  auto& x_strides = x.strides();
+  int w_batch_ndims = w.ndim() - 2;
+  auto& w_shape = w.shape();
+  auto& w_strides = w.strides();
+  auto& s_strides = scales.strides();
+  auto& b_strides = biases.strides();
+
+  int D = x.shape(-1);
+  int B = x.shape(-2);
+  int O = out.shape(-1);
+  int N = out.size() / B / O;
+  if (transpose_) {
+    // Route to the fast bs_qmv kernel that has no bounds checking
+    if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
+      std::ostringstream kname;
+      kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+            << bits_ << "_fast";
+
+      // Encode and dispatch kernel
+      auto& compute_encoder = d.get_command_encoder(s.index);
+      auto kernel = d.get_kernel(kname.str());
+      compute_encoder->setComputePipelineState(kernel);
+
+      int bo = 8;
+      int bd = 32;
+      MTL::Size group_dims = MTL::Size(bd, 2, 1);
+      MTL::Size grid_dims = MTL::Size(O / bo, B, N);
+
+      compute_encoder.set_input_array(w, 0);
+      compute_encoder.set_input_array(scales, 1);
+      compute_encoder.set_input_array(biases, 2);
+      compute_encoder.set_input_array(x, 3);
+      compute_encoder.set_input_array(lhs_indices, 4);
+      compute_encoder.set_input_array(rhs_indices, 5);
+      compute_encoder.set_output_array(out, 6);
+      compute_encoder->setBytes(&D, sizeof(int), 7);
+      compute_encoder->setBytes(&O, sizeof(int), 8);
+
+      compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
+      set_vector_bytes(compute_encoder, batch_shape, 10);
+      set_vector_bytes(compute_encoder, lhs_strides, 11);
+      set_vector_bytes(compute_encoder, rhs_strides, 12);
+
+      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
+      set_vector_bytes(compute_encoder, x_shape, 14);
+      set_vector_bytes(compute_encoder, x_strides, 15);
+      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
+      set_vector_bytes(compute_encoder, w_shape, 17);
+      set_vector_bytes(compute_encoder, w_strides, 18);
+      set_vector_bytes(compute_encoder, s_strides, 19);
+      set_vector_bytes(compute_encoder, b_strides, 20);
+
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    }
+
+    else if (B < 6) {
+      std::ostringstream kname;
+      kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+            << bits_;
+
+      // Encode and dispatch kernel
+      auto& compute_encoder = d.get_command_encoder(s.index);
+      auto kernel = d.get_kernel(kname.str());
+      compute_encoder->setComputePipelineState(kernel);
+
+      int bo = 8;
+      int bd = 32;
+      MTL::Size group_dims = MTL::Size(bd, 2, 1);
+      MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
+
+      compute_encoder.set_input_array(w, 0);
+      compute_encoder.set_input_array(scales, 1);
+      compute_encoder.set_input_array(biases, 2);
+      compute_encoder.set_input_array(x, 3);
+      compute_encoder.set_input_array(lhs_indices, 4);
+      compute_encoder.set_input_array(rhs_indices, 5);
+      compute_encoder.set_output_array(out, 6);
+      compute_encoder->setBytes(&D, sizeof(int), 7);
+      compute_encoder->setBytes(&O, sizeof(int), 8);
+
+      compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
+      set_vector_bytes(compute_encoder, batch_shape, 10);
+      set_vector_bytes(compute_encoder, lhs_strides, 11);
+      set_vector_bytes(compute_encoder, rhs_strides, 12);
+
+      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
+      set_vector_bytes(compute_encoder, x_shape, 14);
+      set_vector_bytes(compute_encoder, x_strides, 15);
+      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
+      set_vector_bytes(compute_encoder, w_shape, 17);
+      set_vector_bytes(compute_encoder, w_strides, 18);
+      set_vector_bytes(compute_encoder, s_strides, 19);
+      set_vector_bytes(compute_encoder, b_strides, 20);
+
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    }
+
+    // Route to the bs_qmm_t
+    else {
+      std::ostringstream kname;
+      kname << "bs_qmm_t_" << type_to_name(out) << "_gs_" << group_size_
+            << "_b_" << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
+
+      // Encode and dispatch kernel
+      auto& compute_encoder = d.get_command_encoder(s.index);
+      auto kernel = d.get_kernel(kname.str());
+      compute_encoder->setComputePipelineState(kernel);
+
+      int wn = 2;
+      int wm = 2;
+      int bm = 32;
+      int bn = 32;
+      int bk = 32;
+      MTL::Size group_dims = MTL::Size(32, wn, wm);
+      MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N);
+
+      compute_encoder.set_input_array(x, 0);
+      compute_encoder.set_input_array(w, 1);
+      compute_encoder.set_input_array(scales, 2);
+      compute_encoder.set_input_array(biases, 3);
+      compute_encoder.set_input_array(lhs_indices, 4);
+      compute_encoder.set_input_array(rhs_indices, 5);
+      compute_encoder.set_output_array(out, 6);
+      compute_encoder->setBytes(&B, sizeof(int), 7);
+      compute_encoder->setBytes(&O, sizeof(int), 8);
+      compute_encoder->setBytes(&D, sizeof(int), 9);
+
+      compute_encoder->setBytes(&batch_ndims, sizeof(int), 10);
+      set_vector_bytes(compute_encoder, batch_shape, 11);
+      set_vector_bytes(compute_encoder, lhs_strides, 12);
+      set_vector_bytes(compute_encoder, rhs_strides, 13);
+
+      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14);
+      set_vector_bytes(compute_encoder, x_shape, 15);
+      set_vector_bytes(compute_encoder, x_strides, 16);
+      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17);
+      set_vector_bytes(compute_encoder, w_shape, 18);
+      set_vector_bytes(compute_encoder, w_strides, 19);
+      set_vector_bytes(compute_encoder, s_strides, 20);
+      set_vector_bytes(compute_encoder, b_strides, 21);
+
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    }
+  } else {
+    // Route to the bs_qvm kernel
+    if (B < 4) {
+      std::ostringstream kname;
+      kname << "bs_qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+            << bits_;
+
+      // Encode and dispatch kernel
+      auto& compute_encoder = d.get_command_encoder(s.index);
+      auto kernel = d.get_kernel(kname.str());
+      compute_encoder->setComputePipelineState(kernel);
+
+      int bo = 8;
+      int bd = 32;
+      MTL::Size group_dims = MTL::Size(bd, bo, 1);
+      MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
+
+      compute_encoder.set_input_array(x, 0);
+      compute_encoder.set_input_array(w, 1);
+      compute_encoder.set_input_array(scales, 2);
+      compute_encoder.set_input_array(biases, 3);
+      compute_encoder.set_input_array(lhs_indices, 4);
+      compute_encoder.set_input_array(rhs_indices, 5);
+      compute_encoder.set_output_array(out, 6);
+      compute_encoder->setBytes(&D, sizeof(int), 7);
+      compute_encoder->setBytes(&O, sizeof(int), 8);
+
+      compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
+      set_vector_bytes(compute_encoder, batch_shape, 10);
+      set_vector_bytes(compute_encoder, lhs_strides, 11);
+      set_vector_bytes(compute_encoder, rhs_strides, 12);
+
+      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
+      set_vector_bytes(compute_encoder, x_shape, 14);
+      set_vector_bytes(compute_encoder, x_strides, 15);
+      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
+      set_vector_bytes(compute_encoder, w_shape, 17);
+      set_vector_bytes(compute_encoder, w_strides, 18);
+      set_vector_bytes(compute_encoder, s_strides, 19);
+      set_vector_bytes(compute_encoder, b_strides, 20);
+
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    }
+
+    // Route to bs_qmm_n
+    else {
+      std::ostringstream kname;
+      kname << "bs_qmm_n_" << type_to_name(out) << "_gs_" << group_size_
+            << "_b_" << bits_;
+
+      // Encode and dispatch kernel
+      auto& compute_encoder = d.get_command_encoder(s.index);
+      auto kernel = d.get_kernel(kname.str());
+      compute_encoder->setComputePipelineState(kernel);
+
+      int wn = 2;
+      int wm = 2;
+      int bm = 32;
+      int bn = 32;
+      int bk = 32;
+      MTL::Size group_dims = MTL::Size(32, wn, wm);
+      MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N);
+
+      if ((O % bn) != 0) {
+        std::ostringstream msg;
+        msg << "[quantized_matmul] The output size should be divisible by "
+            << bn << " but received " << O << ".";
+        throw std::runtime_error(msg.str());
+      }
+
+      compute_encoder.set_input_array(x, 0);
+      compute_encoder.set_input_array(w, 1);
+      compute_encoder.set_input_array(scales, 2);
+      compute_encoder.set_input_array(biases, 3);
+      compute_encoder.set_input_array(lhs_indices, 4);
+      compute_encoder.set_input_array(rhs_indices, 5);
+      compute_encoder.set_output_array(out, 6);
+      compute_encoder->setBytes(&B, sizeof(int), 7);
+      compute_encoder->setBytes(&O, sizeof(int), 8);
+      compute_encoder->setBytes(&D, sizeof(int), 9);
+
+      compute_encoder->setBytes(&batch_ndims, sizeof(int), 10);
+      set_vector_bytes(compute_encoder, batch_shape, 11);
+      set_vector_bytes(compute_encoder, lhs_strides, 12);
+      set_vector_bytes(compute_encoder, rhs_strides, 13);
+
+      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14);
+      set_vector_bytes(compute_encoder, x_shape, 15);
+      set_vector_bytes(compute_encoder, x_strides, 16);
+      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17);
+      set_vector_bytes(compute_encoder, w_shape, 18);
+      set_vector_bytes(compute_encoder, w_strides, 19);
+      set_vector_bytes(compute_encoder, s_strides, 20);
+      set_vector_bytes(compute_encoder, b_strides, 21);
+
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    }
+  }
+}
+
 } // namespace mlx::core
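For reference, the kernel selection in BlockSparseQMM::eval_gpu above boils down to the following decision sketch (a summary of the branches in this commit, not code from it):

#include <cstdio>

// Summary of the routing in BlockSparseQMM::eval_gpu (sketch, not the source).
// B = rows of x, O = output columns, D = inner (reduction) dimension.
const char* pick_bs_kernel(bool transpose, int B, int O, int D) {
  if (transpose) {
    if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512)
      return "bs_qmv_fast"; // fast path, no bounds checking
    if (B < 6)
      return "bs_qmv";
    return "bs_qmm_t";
  }
  return (B < 4) ? "bs_qvm" : "bs_qmm_n"; // bs_qmm_n requires O % 32 == 0
}

int main() {
  printf("%s\n", pick_bs_kernel(true, 1, 4096, 4096)); // -> bs_qmv_fast
}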