Block sparse qmm (#1124)

Author: Angelos Katharopoulos, 2024-05-16 15:24:14 -07:00 (committed by GitHub)
Parent: 1873ffda01
Commit: e78a6518fa
15 changed files with 1724 additions and 164 deletions

View File

@@ -33,6 +33,7 @@ DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
+DEFAULT(BlockSparseQMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)

View File

@@ -44,6 +44,7 @@ DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
+DEFAULT(BlockSparseQMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)

View File

@@ -192,7 +192,7 @@ void _qmm_dispatch_typed(
}
void _qmm_dispatch(
-array out,
+array& out,
const array& x,
const array& w,
const array& scales,
@@ -253,6 +253,81 @@ void _qmm_dispatch(
}
}
void _bs_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& biases,
const array& lhs_indices,
const array& rhs_indices,
int bits,
int group_size,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>() + i * M * N,
x.data<float>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>() + i * M * N,
x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>() + i * M * N,
x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
}
} // namespace
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
@@ -282,4 +357,45 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}
void BlockSparseQMM::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto ensure_row_contiguous_last_dims = [](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
}
} // namespace mlx::core
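For reference, the per-pair semantics of _bs_qmm_dispatch above can be written as a short standalone C++ sketch (illustrative only, not part of this commit): each index i selects an x block through lhs_indices and a quantized weight matrix through rhs_indices, dequantizes with the per-group scales and biases, and accumulates out[i] = x[lhs] * dequantize(w[rhs])^T for transposed weights. The sketch assumes least-significant-bit-first packing of 32/bits values per uint32 word, consistent with the shifts used in quantize() later in this diff; all function and variable names below are illustrative.

#include <cstdint>

// Dequantize element k of one packed weight row.
// w_row holds K*bits/32 uint32 words; scales_row/biases_row hold K/group_size entries.
inline float dequant_element(
    const uint32_t* w_row,
    const float* scales_row,
    const float* biases_row,
    int k,
    int bits,
    int group_size) {
  int per_word = 32 / bits; // packed values per uint32
  uint32_t word = w_row[k / per_word];
  uint32_t q = (word >> ((k % per_word) * bits)) & ((1u << bits) - 1);
  int g = k / group_size; // quantization group index
  return scales_row[g] * static_cast<float>(q) + biases_row[g];
}

// One gathered matmul: out is [M, N], x is [M, K], w is [N, K*bits/32] (transposed layout).
void bs_qmm_reference(
    float* out,
    const float* x,
    const uint32_t* w,
    const float* scales,
    const float* biases,
    int M,
    int N,
    int K,
    int bits,
    int group_size) {
  int words_per_row = K * bits / 32;
  int groups_per_row = K / group_size;
  for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
      float acc = 0.0f;
      for (int k = 0; k < K; k++) {
        acc += x[m * K + k] *
            dequant_element(
                w + n * words_per_row,
                scales + n * groups_per_row,
                biases + n * groups_per_row,
                k,
                bits,
                group_size);
      }
      out[m * N + n] = acc;
    }
  }
}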

View File

@@ -378,14 +378,14 @@ struct QuantizedBlockLoader {
};
template <typename T, int group_size, int bits, int packs_per_thread>
-[[kernel]] void qmv_fast(
-const device uint32_t* w [[buffer(0)]],
-const device T* scales [[buffer(1)]],
-const device T* biases [[buffer(2)]],
-const device T* x [[buffer(3)]],
-device T* y [[buffer(4)]],
-const constant int& in_vec_size [[buffer(5)]],
-const constant int& out_vec_size [[buffer(6)]],
+METAL_FUNC void qmv_fast_impl(
+const device uint32_t* w,
+const device T* scales,
+const device T* biases,
+const device T* x,
+device T* y,
+const constant int& in_vec_size,
+const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -404,13 +404,13 @@ template <typename T, int group_size, int bits, int packs_per_thread>
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
-const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
+const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
-x += tid.z * in_vec_size + simd_lid * values_per_thread;
-y += tid.z * out_vec_size + out_row;
+x += tid.y * in_vec_size + simd_lid * values_per_thread;
+y += tid.y * out_vec_size + out_row;
for (int k = 0; k < in_vec_size; k += block_size) {
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
@@ -440,15 +440,15 @@ template <typename T, int group_size, int bits, int packs_per_thread>
}
}
-template <typename T, const int group_size, const int bits>
-[[kernel]] void qmv(
-const device uint32_t* w [[buffer(0)]],
-const device T* scales [[buffer(1)]],
-const device T* biases [[buffer(2)]],
-const device T* x [[buffer(3)]],
-device T* y [[buffer(4)]],
-const constant int& in_vec_size [[buffer(5)]],
-const constant int& out_vec_size [[buffer(6)]],
+template <typename T, int group_size, int bits>
+METAL_FUNC void qmv_impl(
+const device uint32_t* w,
+const device T* scales,
+const device T* biases,
+const device T* x,
+device T* y,
+const constant int& in_vec_size,
+const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -468,7 +468,7 @@ template <typename T, const int group_size, const int bits>
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
-const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
+const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
@@ -482,8 +482,8 @@ template <typename T, const int group_size, const int bits>
w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
-x += tid.z * in_vec_size + simd_lid * values_per_thread;
-y += tid.z * out_vec_size + out_row;
+x += tid.y * in_vec_size + simd_lid * values_per_thread;
+y += tid.y * out_vec_size + out_row;
int k = 0;
for (; k < in_vec_size - block_size; k += block_size) {
@@ -537,8 +537,8 @@ template <typename T, const int group_size, const int bits>
w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
-x += tid.z * in_vec_size + simd_lid * values_per_thread;
-y += tid.z * out_vec_size + used_out_row;
+x += tid.y * in_vec_size + simd_lid * values_per_thread;
+y += tid.y * out_vec_size + used_out_row;
int k = 0;
for (; k < in_vec_size - block_size; k += block_size) {
@@ -590,14 +590,14 @@ template <typename T, const int group_size, const int bits>
}
template <typename T, const int group_size, const int bits>
-[[kernel]] void qvm(
-const device T* x [[buffer(0)]],
-const device uint32_t* w [[buffer(1)]],
-const device T* scales [[buffer(2)]],
-const device T* biases [[buffer(3)]],
-device T* y [[buffer(4)]],
-const constant int& in_vec_size [[buffer(5)]],
-const constant int& out_vec_size [[buffer(6)]],
+METAL_FUNC void qvm_impl(
+const device T* x,
+const device uint32_t* w,
+const device T* scales,
+const device T* biases,
+device T* y,
+const constant int& in_vec_size,
+const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -616,12 +616,12 @@ template <typename T, const int group_size, const int bits>
// Adjust positions
const int out_vec_size_w = out_vec_size / pack_factor;
const int out_vec_size_g = out_vec_size / group_size;
-int out_col = tid.y * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
+int out_col = tid.x * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
w += out_col / pack_factor;
scales += out_col / group_size;
biases += out_col / group_size;
-x += tid.z * in_vec_size;
-y += tid.z * out_vec_size + out_col;
+x += tid.y * in_vec_size;
+y += tid.y * out_vec_size + out_col;
if (out_col >= out_vec_size) {
return;
@@ -675,15 +675,17 @@ template <
const int group_size,
const int bits,
const bool aligned_N>
-[[kernel]] void qmm_t(
-const device T* x [[buffer(0)]],
-const device uint32_t* w [[buffer(1)]],
-const device T* scales [[buffer(2)]],
-const device T* biases [[buffer(3)]],
-device T* y [[buffer(4)]],
-const constant int& M [[buffer(5)]],
-const constant int& N [[buffer(6)]],
-const constant int& K [[buffer(7)]],
+METAL_FUNC void qmm_t_impl(
+const device T* x,
+const device uint32_t* w,
+const device T* scales,
+const device T* biases,
+device T* y,
+threadgroup T* Xs,
+threadgroup T* Ws,
+const constant int& M,
+const constant int& N,
+const constant int& K,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -713,9 +715,6 @@ template <
group_size,
bits>;
-threadgroup T Xs[BM * BK_padded];
-threadgroup T Ws[BN * BK_padded];
// Set the block
const int K_w = K / pack_factor;
const int K_g = K / group_size;
@@ -797,15 +796,17 @@ template <
const int BN,
const int group_size,
const int bits>
-[[kernel]] void qmm_n(
-const device T* x [[buffer(0)]],
-const device uint32_t* w [[buffer(1)]],
-const device T* scales [[buffer(2)]],
-const device T* biases [[buffer(3)]],
-device T* y [[buffer(4)]],
-const constant int& M [[buffer(5)]],
-const constant int& N [[buffer(6)]],
-const constant int& K [[buffer(7)]],
+METAL_FUNC void qmm_n_impl(
+const device T* x,
+const device uint32_t* w,
+const device T* scales,
+const device T* biases,
+device T* y,
+threadgroup T* Xs,
+threadgroup T* Ws,
+const constant int& M,
+const constant int& N,
+const constant int& K,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -836,9 +837,6 @@ template <
group_size,
bits>;
-threadgroup T Xs[BM * BK_padded];
-threadgroup T Ws[BK * BN_padded];
// Set the block
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
@@ -923,6 +921,518 @@ template <
}
}
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device T*& scales,
const device T*& biases,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
device T*& y,
int output_stride,
const constant int& batch_ndims,
const constant int* batch_shape,
const constant size_t* lhs_strides,
const constant size_t* rhs_strides,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant size_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant size_t* w_strides,
const constant size_t* s_strides,
const constant size_t* b_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx;
uint32_t w_idx;
if (batch_ndims == 1) {
x_idx = lhs_indices[tid.z * lhs_strides[0]];
w_idx = rhs_indices[tid.z * rhs_strides[0]];
} else {
ulong2 idx = elem_to_loc_broadcast(
tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
x_idx = lhs_indices[idx.x];
w_idx = rhs_indices[idx.y];
}
if (x_batch_ndims == 1) {
x += x_idx * x_strides[0];
} else {
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
}
if (w_batch_ndims == 1) {
w += w_idx * w_strides[0];
scales += w_idx * s_strides[0];
biases += w_idx * b_strides[0];
} else {
ulong3 idx = elem_to_loc_broadcast(
w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
w += idx.x;
scales += idx.y;
biases += idx.z;
}
y += tid.z * output_stride;
}
template <typename T, int group_size, int bits, int packs_per_thread>
[[kernel]] void qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
qmv_fast_impl<T, group_size, bits, packs_per_thread>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
template <typename T, const int group_size, const int bits>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
qmv_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
template <typename T, const int group_size, const int bits>
[[kernel]] void qvm(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
qvm_impl<T, group_size, bits>(
x,
w,
scales,
biases,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
template <
typename T,
const int BM,
const int BK,
const int BN,
const int group_size,
const int bits,
const bool aligned_N>
[[kernel]] void qmm_t(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& M [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& K [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr int BK_padded = (BK + 16 / sizeof(T));
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
template <
typename T,
const int BM,
const int BK,
const int BN,
const int group_size,
const int bits>
[[kernel]] void qmm_n(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& M [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& K [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded];
qmm_n_impl<T, BM, BK, BN, group_size, bits>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
template <typename T, int group_size, int bits, int packs_per_thread>
[[kernel]] void bs_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& batch_ndims [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* lhs_strides [[buffer(11)]],
const constant size_t* rhs_strides [[buffer(12)]],
const constant int& x_batch_ndims [[buffer(13)]],
const constant int* x_shape [[buffer(14)]],
const constant size_t* x_strides [[buffer(15)]],
const constant int& w_batch_ndims [[buffer(16)]],
const constant int* w_shape [[buffer(17)]],
const constant size_t* w_strides [[buffer(18)]],
const constant size_t* s_strides [[buffer(19)]],
const constant size_t* b_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
y,
out_vec_size,
batch_ndims,
batch_shape,
lhs_strides,
rhs_strides,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
qmv_fast_impl<T, group_size, bits, packs_per_thread>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& batch_ndims [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* lhs_strides [[buffer(11)]],
const constant size_t* rhs_strides [[buffer(12)]],
const constant int& x_batch_ndims [[buffer(13)]],
const constant int* x_shape [[buffer(14)]],
const constant size_t* x_strides [[buffer(15)]],
const constant int& w_batch_ndims [[buffer(16)]],
const constant int* w_shape [[buffer(17)]],
const constant size_t* w_strides [[buffer(18)]],
const constant size_t* s_strides [[buffer(19)]],
const constant size_t* b_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
y,
out_vec_size,
batch_ndims,
batch_shape,
lhs_strides,
rhs_strides,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
qmv_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& batch_ndims [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* lhs_strides [[buffer(11)]],
const constant size_t* rhs_strides [[buffer(12)]],
const constant int& x_batch_ndims [[buffer(13)]],
const constant int* x_shape [[buffer(14)]],
const constant size_t* x_strides [[buffer(15)]],
const constant int& w_batch_ndims [[buffer(16)]],
const constant int* w_shape [[buffer(17)]],
const constant size_t* w_strides [[buffer(18)]],
const constant size_t* s_strides [[buffer(19)]],
const constant size_t* b_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
y,
out_vec_size,
batch_ndims,
batch_shape,
lhs_strides,
rhs_strides,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
qvm_impl<T, group_size, bits>(
x,
w,
scales,
biases,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
template <
typename T,
const int BM,
const int BK,
const int BN,
const int group_size,
const int bits,
const bool aligned_N>
[[kernel]] void bs_qmm_t(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& M [[buffer(7)]],
const constant int& N [[buffer(8)]],
const constant int& K [[buffer(9)]],
const constant int& batch_ndims [[buffer(10)]],
const constant int* batch_shape [[buffer(11)]],
const constant size_t* lhs_strides [[buffer(12)]],
const constant size_t* rhs_strides [[buffer(13)]],
const constant int& x_batch_ndims [[buffer(14)]],
const constant int* x_shape [[buffer(15)]],
const constant size_t* x_strides [[buffer(16)]],
const constant int& w_batch_ndims [[buffer(17)]],
const constant int* w_shape [[buffer(18)]],
const constant size_t* w_strides [[buffer(19)]],
const constant size_t* s_strides [[buffer(20)]],
const constant size_t* b_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr int BK_padded = (BK + 16 / sizeof(T));
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
y,
M * N,
batch_ndims,
batch_shape,
lhs_strides,
rhs_strides,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
template <
typename T,
const int BM,
const int BK,
const int BN,
const int group_size,
const int bits>
[[kernel]] void bs_qmm_n(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& M [[buffer(7)]],
const constant int& N [[buffer(8)]],
const constant int& K [[buffer(9)]],
const constant int& batch_ndims [[buffer(10)]],
const constant int* batch_shape [[buffer(11)]],
const constant size_t* lhs_strides [[buffer(12)]],
const constant size_t* rhs_strides [[buffer(13)]],
const constant int& x_batch_ndims [[buffer(14)]],
const constant int* x_shape [[buffer(15)]],
const constant size_t* x_strides [[buffer(16)]],
const constant int& w_batch_ndims [[buffer(17)]],
const constant int* w_shape [[buffer(18)]],
const constant size_t* w_strides [[buffer(19)]],
const constant size_t* s_strides [[buffer(20)]],
const constant size_t* b_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded];
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
y,
M * N,
batch_ndims,
batch_shape,
lhs_strides,
rhs_strides,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
qmm_n_impl<T, BM, BK, BN, group_size, bits>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
#define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits \
"_fast")]] [[kernel]] void \
@@ -1089,3 +1599,241 @@ instantiate_qmm_n_types( 64, 8)
instantiate_qmm_n_types( 32, 2)
instantiate_qmm_n_types( 32, 4)
instantiate_qmm_n_types( 32, 8) // clang-format on
#define instantiate_bs_qmv_fast( \
name, itype, group_size, bits, packs_per_thread) \
template [[host_name("bs_qmv_" #name "_gs_" #group_size "_b_" #bits \
"_fast")]] [[kernel]] void \
bs_qmv_fast<itype, group_size, bits, packs_per_thread>( \
const device uint32_t* w [[buffer(0)]], \
const device itype* scales [[buffer(1)]], \
const device itype* biases [[buffer(2)]], \
const device itype* x [[buffer(3)]], \
const device uint32_t* lhs_indices [[buffer(4)]], \
const device uint32_t* rhs_indices [[buffer(5)]], \
device itype* y [[buffer(6)]], \
const constant int& in_vec_size [[buffer(7)]], \
const constant int& out_vec_size [[buffer(8)]], \
const constant int& batch_ndims [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* lhs_strides [[buffer(11)]], \
const constant size_t* rhs_strides [[buffer(12)]], \
const constant int& x_batch_ndims [[buffer(13)]], \
const constant int* x_shape [[buffer(14)]], \
const constant size_t* x_strides [[buffer(15)]], \
const constant int& w_batch_ndims [[buffer(16)]], \
const constant int* w_shape [[buffer(17)]], \
const constant size_t* w_strides [[buffer(18)]], \
const constant size_t* s_strides [[buffer(19)]], \
const constant size_t* b_strides [[buffer(20)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_bs_qmv_fast_types(group_size, bits, packs_per_thread) \
instantiate_bs_qmv_fast(float32, float, group_size, bits, packs_per_thread) \
instantiate_bs_qmv_fast(float16, half, group_size, bits, packs_per_thread) \
instantiate_bs_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread) // clang-format on
// clang-format off
instantiate_bs_qmv_fast_types(128, 2, 1)
instantiate_bs_qmv_fast_types(128, 4, 2)
instantiate_bs_qmv_fast_types(128, 8, 2)
instantiate_bs_qmv_fast_types( 64, 2, 1)
instantiate_bs_qmv_fast_types( 64, 4, 2)
instantiate_bs_qmv_fast_types( 64, 8, 2)
instantiate_bs_qmv_fast_types( 32, 2, 1)
instantiate_bs_qmv_fast_types( 32, 4, 2)
instantiate_bs_qmv_fast_types( 32, 8, 2) // clang-format on
#define instantiate_bs_qmv(name, itype, group_size, bits) \
template [[host_name("bs_qmv_" #name "_gs_" #group_size \
"_b_" #bits)]] [[kernel]] void \
bs_qmv<itype, group_size, bits>( \
const device uint32_t* w [[buffer(0)]], \
const device itype* scales [[buffer(1)]], \
const device itype* biases [[buffer(2)]], \
const device itype* x [[buffer(3)]], \
const device uint32_t* lhs_indices [[buffer(4)]], \
const device uint32_t* rhs_indices [[buffer(5)]], \
device itype* y [[buffer(6)]], \
const constant int& in_vec_size [[buffer(7)]], \
const constant int& out_vec_size [[buffer(8)]], \
const constant int& batch_ndims [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* lhs_strides [[buffer(11)]], \
const constant size_t* rhs_strides [[buffer(12)]], \
const constant int& x_batch_ndims [[buffer(13)]], \
const constant int* x_shape [[buffer(14)]], \
const constant size_t* x_strides [[buffer(15)]], \
const constant int& w_batch_ndims [[buffer(16)]], \
const constant int* w_shape [[buffer(17)]], \
const constant size_t* w_strides [[buffer(18)]], \
const constant size_t* s_strides [[buffer(19)]], \
const constant size_t* b_strides [[buffer(20)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_bs_qmv_types(group_size, bits) \
instantiate_bs_qmv(float32, float, group_size, bits) \
instantiate_bs_qmv(float16, half, group_size, bits) \
instantiate_bs_qmv(bfloat16, bfloat16_t, group_size, bits) // clang-format on
// clang-format off
instantiate_bs_qmv_types(128, 2)
instantiate_bs_qmv_types(128, 4)
instantiate_bs_qmv_types(128, 8)
instantiate_bs_qmv_types( 64, 2)
instantiate_bs_qmv_types( 64, 4)
instantiate_bs_qmv_types( 64, 8)
instantiate_bs_qmv_types( 32, 2)
instantiate_bs_qmv_types( 32, 4)
instantiate_bs_qmv_types( 32, 8) // clang-format on
#define instantiate_bs_qvm(name, itype, group_size, bits) \
template [[host_name("bs_qvm_" #name "_gs_" #group_size \
"_b_" #bits)]] [[kernel]] void \
bs_qvm<itype, group_size, bits>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
const device itype* biases [[buffer(3)]], \
const device uint32_t* lhs_indices [[buffer(4)]], \
const device uint32_t* rhs_indices [[buffer(5)]], \
device itype* y [[buffer(6)]], \
const constant int& in_vec_size [[buffer(7)]], \
const constant int& out_vec_size [[buffer(8)]], \
const constant int& batch_ndims [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* lhs_strides [[buffer(11)]], \
const constant size_t* rhs_strides [[buffer(12)]], \
const constant int& x_batch_ndims [[buffer(13)]], \
const constant int* x_shape [[buffer(14)]], \
const constant size_t* x_strides [[buffer(15)]], \
const constant int& w_batch_ndims [[buffer(16)]], \
const constant int* w_shape [[buffer(17)]], \
const constant size_t* w_strides [[buffer(18)]], \
const constant size_t* s_strides [[buffer(19)]], \
const constant size_t* b_strides [[buffer(20)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_bs_qvm_types(group_size, bits) \
instantiate_bs_qvm(float32, float, group_size, bits) \
instantiate_bs_qvm(float16, half, group_size, bits) \
instantiate_bs_qvm(bfloat16, bfloat16_t, group_size, bits) // clang-format on
// clang-format off
instantiate_bs_qvm_types(128, 2)
instantiate_bs_qvm_types(128, 4)
instantiate_bs_qvm_types(128, 8)
instantiate_bs_qvm_types( 64, 2)
instantiate_bs_qvm_types( 64, 4)
instantiate_bs_qvm_types( 64, 8)
instantiate_bs_qvm_types( 32, 2)
instantiate_bs_qvm_types( 32, 4)
instantiate_bs_qvm_types( 32, 8) // clang-format on
#define instantiate_bs_qmm_t(name, itype, group_size, bits, aligned_N) \
template [[host_name("bs_qmm_t_" #name "_gs_" #group_size "_b_" #bits \
"_alN_" #aligned_N)]] [[kernel]] void \
bs_qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
const device itype* biases [[buffer(3)]], \
const device uint32_t* lhs_indices [[buffer(4)]], \
const device uint32_t* rhs_indices [[buffer(5)]], \
device itype* y [[buffer(6)]], \
const constant int& M [[buffer(7)]], \
const constant int& N [[buffer(8)]], \
const constant int& K [[buffer(9)]], \
const constant int& batch_ndims [[buffer(10)]], \
const constant int* batch_shape [[buffer(11)]], \
const constant size_t* lhs_strides [[buffer(12)]], \
const constant size_t* rhs_strides [[buffer(13)]], \
const constant int& x_batch_ndims [[buffer(14)]], \
const constant int* x_shape [[buffer(15)]], \
const constant size_t* x_strides [[buffer(16)]], \
const constant int& w_batch_ndims [[buffer(17)]], \
const constant int* w_shape [[buffer(18)]], \
const constant size_t* w_strides [[buffer(19)]], \
const constant size_t* s_strides [[buffer(20)]], \
const constant size_t* b_strides [[buffer(21)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_bs_qmm_t_types(group_size, bits) \
instantiate_bs_qmm_t(float32, float, group_size, bits, false) \
instantiate_bs_qmm_t(float16, half, group_size, bits, false) \
instantiate_bs_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \
instantiate_bs_qmm_t(float32, float, group_size, bits, true) \
instantiate_bs_qmm_t(float16, half, group_size, bits, true) \
instantiate_bs_qmm_t(bfloat16, bfloat16_t, group_size, bits, true) // clang-format on
// clang-format off
instantiate_bs_qmm_t_types(128, 2)
instantiate_bs_qmm_t_types(128, 4)
instantiate_bs_qmm_t_types(128, 8)
instantiate_bs_qmm_t_types( 64, 2)
instantiate_bs_qmm_t_types( 64, 4)
instantiate_bs_qmm_t_types( 64, 8)
instantiate_bs_qmm_t_types( 32, 2)
instantiate_bs_qmm_t_types( 32, 4)
instantiate_bs_qmm_t_types( 32, 8) // clang-format on
#define instantiate_bs_qmm_n(name, itype, group_size, bits) \
template [[host_name("bs_qmm_n_" #name "_gs_" #group_size \
"_b_" #bits)]] [[kernel]] void \
bs_qmm_n<itype, 32, 32, 32, group_size, bits>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
const device itype* biases [[buffer(3)]], \
const device uint32_t* lhs_indices [[buffer(4)]], \
const device uint32_t* rhs_indices [[buffer(5)]], \
device itype* y [[buffer(6)]], \
const constant int& M [[buffer(7)]], \
const constant int& N [[buffer(8)]], \
const constant int& K [[buffer(9)]], \
const constant int& batch_ndims [[buffer(10)]], \
const constant int* batch_shape [[buffer(11)]], \
const constant size_t* lhs_strides [[buffer(12)]], \
const constant size_t* rhs_strides [[buffer(13)]], \
const constant int& x_batch_ndims [[buffer(14)]], \
const constant int* x_shape [[buffer(15)]], \
const constant size_t* x_strides [[buffer(16)]], \
const constant int& w_batch_ndims [[buffer(17)]], \
const constant int* w_shape [[buffer(18)]], \
const constant size_t* w_strides [[buffer(19)]], \
const constant size_t* s_strides [[buffer(20)]], \
const constant size_t* b_strides [[buffer(21)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_bs_qmm_n_types(group_size, bits) \
instantiate_bs_qmm_n(float32, float, group_size, bits) \
instantiate_bs_qmm_n(float16, half, group_size, bits) \
instantiate_bs_qmm_n(bfloat16, bfloat16_t, group_size, bits) // clang-format on
// clang-format off
instantiate_bs_qmm_n_types(128, 2)
instantiate_bs_qmm_n_types(128, 4)
instantiate_bs_qmm_n_types(128, 8)
instantiate_bs_qmm_n_types( 64, 2)
instantiate_bs_qmm_n_types( 64, 4)
instantiate_bs_qmm_n_types( 64, 8)
instantiate_bs_qmm_n_types( 32, 2)
instantiate_bs_qmm_n_types( 32, 4)
instantiate_bs_qmm_n_types( 32, 8) // clang-format on
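The [[host_name(...)]] strings produced by the instantiation macros above are exactly the names that BlockSparseQMM::eval_gpu (next file) requests from the Metal device. A minimal illustration of the name assembly, using example parameters (float16 output, group_size 64, bits 4) rather than anything fixed by the commit:

#include <iostream>
#include <sstream>

int main() {
  std::ostringstream kname;
  // Mirrors the kname construction in BlockSparseQMM::eval_gpu for the fast path.
  kname << "bs_qmv_" << "float16" << "_gs_" << 64 << "_b_" << 4 << "_fast";
  std::cout << kname.str() << "\n"; // prints: bs_qmv_float16_gs_64_b_4_fast
  return 0;
}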

View File

@@ -55,7 +55,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
-MTL::Size grid_dims = MTL::Size(1, O / bo, B);
+MTL::Size grid_dims = MTL::Size(O / bo, B, 1);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
@@ -82,7 +82,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
-MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
+MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
@@ -140,7 +140,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
-MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
+MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
@@ -196,4 +196,289 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
void BlockSparseQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
// TODO: collapse batch dims
auto& batch_shape = lhs_indices.shape();
int batch_ndims = batch_shape.size();
auto& lhs_strides = lhs_indices.strides();
auto& rhs_strides = rhs_indices.strides();
// Ensure that the last two dims are row contiguous.
// TODO: Check if we really need this for x as well...
std::vector<array> copies;
auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
int x_batch_ndims = x.ndim() - 2;
auto& x_shape = x.shape();
auto& x_strides = x.strides();
int w_batch_ndims = w.ndim() - 2;
auto& w_shape = w.shape();
auto& w_strides = w.strides();
auto& s_strides = scales.strides();
auto& b_strides = biases.strides();
int D = x.shape(-1);
int B = x.shape(-2);
int O = out.shape(-1);
int N = out.size() / B / O;
if (transpose_) {
// Route to the fast bs_qmv kernel that has no bounds checking
if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
std::ostringstream kname;
kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
<< bits_ << "_fast";
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, N);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, lhs_strides, 11);
set_vector_bytes(compute_encoder, rhs_strides, 12);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
set_vector_bytes(compute_encoder, x_shape, 14);
set_vector_bytes(compute_encoder, x_strides, 15);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
set_vector_bytes(compute_encoder, w_shape, 17);
set_vector_bytes(compute_encoder, w_strides, 18);
set_vector_bytes(compute_encoder, s_strides, 19);
set_vector_bytes(compute_encoder, b_strides, 20);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
else if (B < 6) {
std::ostringstream kname;
kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, lhs_strides, 11);
set_vector_bytes(compute_encoder, rhs_strides, 12);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
set_vector_bytes(compute_encoder, x_shape, 14);
set_vector_bytes(compute_encoder, x_strides, 15);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
set_vector_bytes(compute_encoder, w_shape, 17);
set_vector_bytes(compute_encoder, w_strides, 18);
set_vector_bytes(compute_encoder, s_strides, 19);
set_vector_bytes(compute_encoder, b_strides, 20);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the bs_qmm_t
else {
std::ostringstream kname;
kname << "bs_qmm_t_" << type_to_name(out) << "_gs_" << group_size_
<< "_b_" << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 32;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&B, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&D, sizeof(int), 9);
compute_encoder->setBytes(&batch_ndims, sizeof(int), 10);
set_vector_bytes(compute_encoder, batch_shape, 11);
set_vector_bytes(compute_encoder, lhs_strides, 12);
set_vector_bytes(compute_encoder, rhs_strides, 13);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14);
set_vector_bytes(compute_encoder, x_shape, 15);
set_vector_bytes(compute_encoder, x_strides, 16);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17);
set_vector_bytes(compute_encoder, w_shape, 18);
set_vector_bytes(compute_encoder, w_strides, 19);
set_vector_bytes(compute_encoder, s_strides, 20);
set_vector_bytes(compute_encoder, b_strides, 21);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
} else {
// Route to the bs_qvm kernel
if (B < 4) {
std::ostringstream kname;
kname << "bs_qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, lhs_strides, 11);
set_vector_bytes(compute_encoder, rhs_strides, 12);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
set_vector_bytes(compute_encoder, x_shape, 14);
set_vector_bytes(compute_encoder, x_strides, 15);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
set_vector_bytes(compute_encoder, w_shape, 17);
set_vector_bytes(compute_encoder, w_strides, 18);
set_vector_bytes(compute_encoder, s_strides, 19);
set_vector_bytes(compute_encoder, b_strides, 20);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to bs_qmm_n
else {
std::ostringstream kname;
kname << "bs_qmm_n_" << type_to_name(out) << "_gs_" << group_size_
<< "_b_" << bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 32;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N);
if ((O % bn) != 0) {
std::ostringstream msg;
msg << "[quantized_matmul] The output size should be divisible by "
<< bn << " but received " << O << ".";
throw std::runtime_error(msg.str());
}
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&B, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&D, sizeof(int), 9);
compute_encoder->setBytes(&batch_ndims, sizeof(int), 10);
set_vector_bytes(compute_encoder, batch_shape, 11);
set_vector_bytes(compute_encoder, lhs_strides, 12);
set_vector_bytes(compute_encoder, rhs_strides, 13);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14);
set_vector_bytes(compute_encoder, x_shape, 15);
set_vector_bytes(compute_encoder, x_strides, 16);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17);
set_vector_bytes(compute_encoder, w_shape, 18);
set_vector_bytes(compute_encoder, w_strides, 19);
set_vector_bytes(compute_encoder, s_strides, 20);
set_vector_bytes(compute_encoder, b_strides, 21);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}
}
} // namespace mlx::core
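The kernel routing in BlockSparseQMM::eval_gpu above mirrors QuantizedMatmul::eval_gpu. Summarized as a small sketch (illustrative only, not part of the commit), with B = x.shape(-2), O = out.shape(-1), and D = x.shape(-1):

// Illustrative summary of the dispatch above.
const char* pick_bs_kernel(bool transpose, int B, int O, int D) {
  if (transpose) {
    if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
      return "bs_qmv_*_fast"; // no bounds checking
    }
    if (B < 6) {
      return "bs_qmv_*";
    }
    return "bs_qmm_t_*"; // aligned_N variant chosen when O % 32 == 0
  }
  if (B < 4) {
    return "bs_qvm_*";
  }
  return "bs_qmm_n_*"; // requires O divisible by 32, otherwise an error is thrown
}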

View File

@@ -35,6 +35,7 @@ NO_GPU(AsStrided)
NO_GPU(BitwiseBinary)
NO_GPU(BlockMaskedMM)
NO_GPU(BlockSparseMM)
+NO_GPU(BlockSparseQMM)
NO_GPU(Broadcast)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)

View File

@@ -50,6 +50,83 @@ Dtype at_least_float(const Dtype& d) {
return issubdtype(d, inexact) ? d : promote_types(d, float32);
}
array indices_or_default(
std::optional<array> indices,
const array& x,
StreamOrDevice s) {
if (indices.has_value()) {
return indices.value();
}
std::vector<int> shape(x.shape().begin(), x.shape().end() - 2);
int total =
std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
return reshape(arange(total, uint32, s), shape, s);
}
std::pair<int, int> extract_quantized_matmul_dims(
std::string_view tag,
const array& x,
const array& w,
const array& scales,
const array& biases,
bool transpose,
int group_size,
int bits) {
if (w.dtype() != uint32) {
std::ostringstream msg;
msg << "[" << tag << "] The weight matrix should be uint32 "
<< "but received" << w.dtype();
throw std::invalid_argument(msg.str());
}
if (scales.shape() != biases.shape()) {
std::ostringstream msg;
msg << "[" << tag << "] Scales and biases should have the same shape. "
<< "Received scales with shape " << scales.shape()
<< " and biases with " << biases.shape();
throw std::invalid_argument(msg.str());
}
if (!std::equal(
w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) {
std::ostringstream msg;
msg << "[" << tag
<< "] Weight, scales and biases should have the same batch shape. "
<< "Received weight with shape " << w.shape() << ", scales with "
<< scales.shape() << " and biases with " << biases.shape();
throw std::invalid_argument(msg.str());
}
if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
std::ostringstream msg;
msg << "[" << tag << "] The shapes of the weight and scales are "
<< "incompatible based on bits and group_size. w.shape() == "
<< w.shape() << " and scales.shape() == " << scales.shape()
<< " with group_size=" << group_size << " and bits=" << bits;
throw std::invalid_argument(msg.str());
}
int x_inner_dims = x.shape(-1);
// Calculate the expanded w's dims
int w_inner_dims = (transpose) ? w.shape(-1) * 32 / bits : w.shape(-2);
int w_outer_dims = (transpose) ? w.shape(-2) : w.shape(-1) * 32 / bits;
if (w_inner_dims != x_inner_dims) {
std::ostringstream msg;
msg << "[" << tag << "] Last dimension of first input with "
<< "shape (..., " << x_inner_dims << ") does not match "
<< "the expanded quantized matrix (" << w_inner_dims << ", "
<< w_outer_dims << ") computed from shape " << w.shape()
<< " with group_size=" << group_size << ", bits=" << bits
<< " and transpose=" << std::boolalpha << transpose;
throw std::invalid_argument(msg.str());
}
return {w_inner_dims, w_outer_dims};
}
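A worked example of the bookkeeping done by extract_quantized_matmul_dims, using hypothetical shapes (4-bit weights, group_size 64, a 4096x4096 transposed weight stored packed as {4096, 512} with scales and biases of shape {4096, 64}); all numbers below are illustrative:

#include <cassert>

int main() {
  // Hypothetical example shapes; 32 / 4 = 8 values are packed per uint32.
  int bits = 4, group_size = 64;
  int w_last = 512, w_second_last = 4096; // packed w has shape {4096, 512}
  int scales_last = 64;                   // scales and biases have shape {4096, 64}

  // Same consistency check as extract_quantized_matmul_dims.
  assert(w_last * 32 / bits == scales_last * group_size); // 4096 == 4096

  int w_inner_dims = w_last * 32 / bits; // 4096, must equal x.shape(-1)
  int w_outer_dims = w_second_last;      // 4096, becomes out.shape(-1)
  assert(w_inner_dims == 4096 && w_outer_dims == 4096);
  return 0;
}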
} // namespace } // namespace
array arange( array arange(
@ -3203,7 +3280,7 @@ array conv_general(
} }
array quantized_matmul( array quantized_matmul(
const array& in_x, const array& x,
const array& w, const array& w,
const array& scales, const array& scales,
const array& biases, const array& biases,
@ -3211,13 +3288,10 @@ array quantized_matmul(
int group_size /* = 64 */, int group_size /* = 64 */,
int bits /* = 4 */, int bits /* = 4 */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
array x = in_x; // Check and extract the quantized matrix shape against x
if (w.dtype() != uint32) { auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
std::ostringstream msg; "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
msg << "[quantized_matmul] The weight matrix should be uint32 "
<< "but received" << w.dtype();
throw std::invalid_argument(msg.str());
}
if (w.ndim() != 2) { if (w.ndim() != 2) {
std::ostringstream msg; std::ostringstream msg;
msg << "[quantized_matmul] Batched quantized matmul is not supported for now " msg << "[quantized_matmul] Batched quantized matmul is not supported for now "
@ -3225,42 +3299,6 @@ array quantized_matmul(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
// Keep x's batch dimensions to reshape it back after the matmul
auto original_shape = x.shape();
int x_inner_dims = original_shape.back();
if (scales.ndim() != 2 || scales.shape() != biases.shape()) {
std::ostringstream msg;
msg << "[quantized_matmul] Scales and biases should have the same 2D shape. "
<< "Received scales with shape " << scales.shape()
<< " and biases with " << biases.shape();
throw std::invalid_argument(msg.str());
}
if (w.shape(1) * 32 / bits != scales.shape(1) * group_size) {
std::ostringstream msg;
msg << "[quantized_matmul] The shapes of the weight and scales are "
<< "incompatible based on bits and group_size. w.shape() == "
<< w.shape() << " and scales.shape() == " << scales.shape()
<< " with group_size=" << group_size << " and bits=" << bits;
throw std::invalid_argument(msg.str());
}
// Calculate the expanded w's dims
int w_inner_dims = (transpose) ? w.shape(1) * 32 / bits : w.shape(0);
int w_outer_dims = (transpose) ? w.shape(0) : w.shape(1) * 32 / bits;
if (w_inner_dims != x_inner_dims) {
std::ostringstream msg;
msg << "[quantized_matmul] Last dimension of first input with "
<< "shape (..., " << x_inner_dims << ") does not match "
<< "the expanded quantized matrix (" << w_inner_dims << ", "
<< w_outer_dims << ") computed from shape " << w.shape()
<< " with group_size=" << group_size << ", bits=" << bits
<< " and transpose=" << std::boolalpha << transpose;
throw std::invalid_argument(msg.str());
}
auto dtype = result_type(x, scales, biases); auto dtype = result_type(x, scales, biases);
if (!issubdtype(dtype, floating)) { if (!issubdtype(dtype, floating)) {
std::ostringstream msg; std::ostringstream msg;
@ -3270,10 +3308,11 @@ array quantized_matmul(
<< " and biases.dtype() == " << biases.dtype(); << " and biases.dtype() == " << biases.dtype();
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
std::vector<array> inputs;
original_shape.back() = w_outer_dims; auto out_shape = x.shape();
out_shape.back() = w_outer_dims;
return array( return array(
std::move(original_shape), std::move(out_shape),
dtype, dtype,
std::make_shared<QuantizedMatmul>( std::make_shared<QuantizedMatmul>(
to_stream(s), group_size, bits, transpose), to_stream(s), group_size, bits, transpose),
@ -3302,11 +3341,14 @@ std::tuple<array, array, array> quantize(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (w.ndim() != 2) { if (w.ndim() < 2) {
throw std::invalid_argument("[quantize] Only matrices supported for now"); std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
} }
if ((w.shape(1) % group_size) != 0) { if ((w.shape(-1) % group_size) != 0) {
std::ostringstream msg; std::ostringstream msg;
msg << "[quantize] The last dimension of the matrix needs to be divisible by " msg << "[quantize] The last dimension of the matrix needs to be divisible by "
<< "the quantization group size " << group_size << "the quantization group size " << group_size
@ -3327,7 +3369,7 @@ std::tuple<array, array, array> quantize(
// at least we bail out early which will result in a nice readable error. // at least we bail out early which will result in a nice readable error.
// //
// Hopefully nobody is quantizing matrices that small anyway. // Hopefully nobody is quantizing matrices that small anyway.
if (w.shape(1) < 32 * el_per_int) { if (w.shape(-1) < 32 * el_per_int) {
std::ostringstream msg; std::ostringstream msg;
msg << "[quantize] The feature dimension (2nd dimension of the matrix) is " msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
<< "too small for quantization. We support >=512 for 2 bits, " << "too small for quantization. We support >=512 for 2 bits, "
@ -3336,9 +3378,12 @@ std::tuple<array, array, array> quantize(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
// Prepare the shape for the outputs.
auto wshape = w.shape();
wshape.back() = -1;
// Compute scales and biases // Compute scales and biases
array packed_w = array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
@ -3357,12 +3402,14 @@ std::tuple<array, array, array> quantize(
zero, zero,
n_bins), n_bins),
uint32); uint32);
packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s); packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
packed_w = sum( packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
return std::make_tuple( return std::make_tuple(
packed_w, squeeze(scales, -1, s), squeeze(biases, -1, s)); reshape(packed_w, wshape, s),
reshape(scales, wshape, s),
reshape(biases, wshape, s));
} }
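To make the new packing concrete, here is a small Python usage sketch (the batched weight shape is made up for illustration; with bits=4 each uint32 packs 32 / 4 = 8 values and each group of 64 elements shares one scale and bias):

import mlx.core as mx

w = mx.random.normal(shape=(4, 1024, 512))   # weights with a batch dimension are now accepted
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
print(w_q.shape)     # (4, 1024, 64)  -- last axis shrinks by 32 / bits = 8
print(scales.shape)  # (4, 1024, 8)   -- one scale per group of 64 elements
print(biases.shape)  # (4, 1024, 8)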
array dequantize( array dequantize(
@ -3382,11 +3429,21 @@ array dequantize(
msg << "[dequantize] Invalid value for group_size: " << group_size; msg << "[dequantize] Invalid value for group_size: " << group_size;
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) { if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
throw std::invalid_argument("[dequantize] Only matrices supported for now"); std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
} }
if (w.shape(0) != scales.shape(0) || w.shape(0) != biases.shape(0)) { auto wshape = w.shape();
auto sshape = scales.shape();
auto bshape = biases.shape();
wshape.back() = -1;
sshape.back() = -1;
bshape.back() = -1;
if (wshape != sshape || wshape != bshape) {
throw std::invalid_argument( throw std::invalid_argument(
"[dequantize] Shape of scales and biases does not match the matrix"); "[dequantize] Shape of scales and biases does not match the matrix");
} }
@ -3399,7 +3456,7 @@ array dequantize(
// Compute some constants for the dequantization // Compute some constants for the dequantization
int el_per_int = 32 / bits; int el_per_int = 32 / bits;
if (w.shape(1) * el_per_int != scales.shape(1) * group_size) { if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) {
std::ostringstream msg; std::ostringstream msg;
msg << "[dequantize] Shape of scales and biases does not match the matrix " msg << "[dequantize] Shape of scales and biases does not match the matrix "
<< "given the quantization parameters. Provided matrix of shape " << "given the quantization parameters. Provided matrix of shape "
@ -3411,25 +3468,79 @@ array dequantize(
// Extract the pieces from the passed quantized matrix // Extract the pieces from the passed quantized matrix
std::vector<array> parts; std::vector<array> parts;
for (int start = 0; start < 32; start += bits) { for (int start = 0; start < 32; start += bits) {
// TODO: Implement bitwise operators for integral types
int shift_left = 32 - (start + bits); int shift_left = 32 - (start + bits);
int shift_right = shift_left + start; int shift_right = shift_left + start;
array p = multiply(w, array(1 << shift_left, uint32), s);
p = floor_divide(p, array(1 << shift_right, uint32), s); parts.push_back(expand_dims(
p = expand_dims(p, -1, s); right_shift(
parts.push_back(p); left_shift(w, array(32 - (start + bits), uint32), s),
array(32 - bits, uint32),
s),
-1,
s));
} }
array w_full = concatenate(parts, -1, s); array w_full = concatenate(parts, -1, s);
// Dequantize // Dequantize
w_full = reshape(w_full, {w.shape(0), -1, group_size}, s); wshape.push_back(group_size);
w_full = reshape(w_full, wshape, s);
w_full = multiply(w_full, expand_dims(scales, -1, s), s); w_full = multiply(w_full, expand_dims(scales, -1, s), s);
w_full = add(w_full, expand_dims(biases, -1, s), s); w_full = add(w_full, expand_dims(biases, -1, s), s);
w_full = reshape(w_full, {w.shape(0), -1}, s); w_full = reshape(w_full, sshape, s);
return w_full; return w_full;
} }
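For reference, a plain-Python sketch of what the left_shift/right_shift pair above extracts from a single packed uint32 (bits=4 is chosen purely for illustration):

bits = 4
word = 0x87654321                   # one uint32 holding 32 // bits = 8 packed values
vals = []
for start in range(0, 32, bits):
    shifted = (word << (32 - (start + bits))) & 0xFFFFFFFF  # discard the higher-order fields
    vals.append(shifted >> (32 - bits))                     # keep only the top `bits` bits
print(vals)  # [1, 2, 3, 4, 5, 6, 7, 8] -- the lowest-order field comes out first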
array block_sparse_qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
std::optional<array> lhs_indices_ /* = std::nullopt */,
std::optional<array> rhs_indices_ /* = std::nullopt */,
bool transpose /* = true */,
int group_size /* = 64 */,
int bits /* = 4 */,
StreamOrDevice s /* = {} */) {
if (!lhs_indices_ && !rhs_indices_) {
return quantized_matmul(
x, w, scales, biases, transpose, group_size, bits, s);
}
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"block_sparse_qmm", x, w, scales, biases, transpose, group_size, bits);
// Extract indices and broadcast them
array lhs_indices = indices_or_default(lhs_indices_, x, s);
array rhs_indices = indices_or_default(rhs_indices_, w, s);
auto out_bsx_shape =
broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);
// Compute the full output shape
auto out_shape = out_bsx_shape;
out_shape.push_back(x.shape(-2));
out_shape.push_back(w_outer_dims);
// and output type
auto out_type = result_type(x, scales, biases);
auto out = array(
std::move(out_shape),
out_type,
std::make_shared<BlockSparseQMM>(
to_stream(s), group_size, bits, transpose),
{astype(x, out_type, s),
w,
astype(scales, out_type, s),
astype(biases, out_type, s),
lhs_indices,
rhs_indices});
return out;
}
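As a concrete illustration of the shape logic above (assuming transpose=True so that w_outer_dims is N; all sizes are hypothetical):

M, K, N = 32, 256, 64
lhs_indices_shape = (3, 1)            # indexes the batch dimensions of x
rhs_indices_shape = (1, 2)            # indexes the batch dimensions of w
out_bsx_shape = (3, 2)                # broadcast of the two index shapes
out_shape = out_bsx_shape + (M, N)    # (3, 2, 32, 64)
print(out_shape)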
array tensordot( array tensordot(
const array& a, const array& a,
const array& b, const array& b,
@ -3879,24 +3990,8 @@ array block_sparse_mm(
b = astype(b, out_type, s); b = astype(b, out_type, s);
// Handle broadcasting // Handle broadcasting
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2); array lhs_indices = indices_or_default(lhs_indices_, a, s);
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2); array rhs_indices = indices_or_default(rhs_indices_, b, s);
auto indices_or_default = [&](const std::optional<array>& indices,
const std::vector<int>& bsx_shape) {
if (indices.has_value()) {
return indices.value();
} else {
int n_batch = 1;
for (auto& i : bsx_shape)
n_batch *= i;
return reshape(arange(n_batch, uint32, s), bsx_shape, s);
}
};
// Pull and broadcast indices
array lhs_indices = indices_or_default(lhs_indices_, bsx_a);
array rhs_indices = indices_or_default(rhs_indices_, bsx_b);
if (!issubdtype(lhs_indices.dtype(), integer)) { if (!issubdtype(lhs_indices.dtype(), integer)) {
throw std::invalid_argument( throw std::invalid_argument(

View File

@ -1157,6 +1157,19 @@ array dequantize(
int bits = 4, int bits = 4,
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Compute quantized matrix products with matrix-level gather. */
array block_sparse_qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
std::optional<array> lhs_indices = std::nullopt,
std::optional<array> rhs_indices = std::nullopt,
bool transpose = true,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {});
/** Returns a contraction of a and b over multiple dimensions. */ /** Returns a contraction of a and b over multiple dimensions. */
array tensordot( array tensordot(
const array& a, const array& a,

View File

@ -2372,7 +2372,85 @@ std::vector<array> QuantizedMatmul::jvp(
bool QuantizedMatmul::is_equivalent(const Primitive& other) const { bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
const QuantizedMatmul& qm_other = static_cast<const QuantizedMatmul&>(other); const QuantizedMatmul& qm_other = static_cast<const QuantizedMatmul&>(other);
return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_; return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ &&
transpose_ == qm_other.transpose_;
}
std::pair<std::vector<array>, std::vector<int>> BlockSparseQMM::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("BlockSparseQMM::vmap NYI");
}
std::vector<array> BlockSparseQMM::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
auto& cotan = cotangents[0];
auto& x = primals[0];
auto& w = primals[1];
auto& scales = primals[2];
auto& biases = primals[3];
auto& lhs_indices = primals[4];
auto& rhs_indices = primals[5];
for (auto arg : argnums) {
// gradient wrt x
if (arg == 0) {
vjps.push_back(reshape(
scatter_add(
flatten(zeros_like(x, stream()), 0, -3, stream()),
lhs_indices,
expand_dims(
block_sparse_qmm(
cotan,
w,
scales,
biases,
std::nullopt,
rhs_indices,
!transpose_,
group_size_,
bits_,
stream()),
-3,
stream()),
0,
stream()),
x.shape(),
stream()));
}
// gradient wrt the indices is undefined
else if (arg > 3) {
throw std::runtime_error(
"BlockSparseQMM::vjp cannot compute the gradient wrt the indices.");
}
// gradient wrt w_q, scales or biases
else {
throw std::runtime_error(
"BlockSparseQMM::vjp no gradient wrt the quantized matrix yet.");
}
}
return vjps;
}
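The x-gradient above scatter-adds each output slice's contribution back into the x slice it was gathered from. A NumPy sketch of the same idea, assuming the transpose=True layout and hypothetical sizes:

import numpy as np

B, M, N, K = 3, 4, 8, 16
cotan = np.random.randn(B, M, N)        # cotangent of the gathered output
w_hat = np.random.randn(2, N, K)        # dequantized weights, one (N, K) matrix per batch entry
lhs_indices = np.array([0, 1, 1])
rhs_indices = np.array([1, 0, 1])

dx = np.zeros((2, M, K))                # x has two batch slices in this sketch
for i in range(B):
    # out[i] = x[lhs_indices[i]] @ w_hat[rhs_indices[i]].T, so the chain rule gives:
    dx[lhs_indices[i]] += cotan[i] @ w_hat[rhs_indices[i]]
print(dx.shape)  # (2, 4, 16)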
std::vector<array> BlockSparseQMM::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
throw std::runtime_error("BlockSparseQMM::jvp NYI");
}
bool BlockSparseQMM::is_equivalent(const Primitive& other) const {
const BlockSparseQMM& qm_other = static_cast<const BlockSparseQMM&>(other);
return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ &&
transpose_ == qm_other.transpose_;
} }
std::pair<std::vector<array>, std::vector<int>> RandomBits::vmap( std::pair<std::vector<array>, std::vector<int>> RandomBits::vmap(

View File

@ -1467,6 +1467,34 @@ class QuantizedMatmul : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
}; };
class BlockSparseQMM : public UnaryPrimitive {
public:
explicit BlockSparseQMM(
Stream stream,
int group_size,
int bits,
bool transpose)
: UnaryPrimitive(stream),
group_size_(group_size),
bits_(bits),
transpose_(transpose) {};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(BlockSparseQMM)
bool is_equivalent(const Primitive& other) const override;
private:
int group_size_;
int bits_;
bool transpose_;
void eval(const std::vector<array>& inputs, array& out);
};
class RandomBits : public UnaryPrimitive { class RandomBits : public UnaryPrimitive {
public: public:
explicit RandomBits(Stream stream, const std::vector<int>& shape, int width) explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)

View File

@ -4,6 +4,7 @@ import math
import mlx.core as mx import mlx.core as mx
from mlx.nn.layers.base import Module from mlx.nn.layers.base import Module
from mlx.nn.layers.quantized import QuantizedEmbedding
class Embedding(Module): class Embedding(Module):
@ -37,3 +38,7 @@ class Embedding(Module):
weights are tied. weights are tied.
""" """
return x @ self.weight.T return x @ self.weight.T
def to_quantized(self, group_size: int = 64, bits: int = 4):
"""Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
return QuantizedEmbedding.from_embedding(self, group_size, bits)

View File

@ -5,6 +5,7 @@ from typing import Any
import mlx.core as mx import mlx.core as mx
from mlx.nn.layers.base import Module from mlx.nn.layers.base import Module
from mlx.nn.layers.quantized import QuantizedLinear
class Identity(Module): class Identity(Module):
@ -69,6 +70,10 @@ class Linear(Module):
x = x @ self["weight"].T x = x @ self["weight"].T
return x return x
def to_quantized(self, group_size: int = 64, bits: int = 4):
"""Return a :obj:`QuantizedLinear` layer that approximates this layer."""
return QuantizedLinear.from_linear(self, group_size, bits)
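A quick sketch of the new hook (layer sizes are arbitrary; nn.Embedding gains the analogous to_quantized method):

import mlx.nn as nn

layer = nn.Linear(512, 256)
qlayer = layer.to_quantized(group_size=64, bits=4)  # same as QuantizedLinear.from_linear(layer, 64, 4)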
class Bilinear(Module): class Bilinear(Module):
r"""Applies a bilinear transformation to the inputs. r"""Applies a bilinear transformation to the inputs.

View File

@ -5,8 +5,6 @@ from typing import Callable, Optional
import mlx.core as mx import mlx.core as mx
from mlx.nn.layers.base import Module from mlx.nn.layers.base import Module
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Linear
from mlx.utils import tree_map_with_path from mlx.utils import tree_map_with_path
@ -18,8 +16,9 @@ def quantize(
): ):
"""Quantize the sub-modules of a module according to a predicate. """Quantize the sub-modules of a module according to a predicate.
By default all :obj:`Linear` and :obj:`Embedding` layers will be By default all layers that define a ``to_quantized(group_size, bits)``
quantized. Note also, the module is updated in-place. method will be quantized. Both :obj:`Linear` and :obj:`Embedding` layers
will be quantized. Note also, the module is updated in-place.
Args: Args:
model (mlx.nn.Module): The model whose leaf modules may be quantized. model (mlx.nn.Module): The model whose leaf modules may be quantized.
@ -30,18 +29,15 @@ def quantize(
class_predicate (Optional[Callable]): A callable which receives the class_predicate (Optional[Callable]): A callable which receives the
:obj:`Module` path and :obj:`Module` itself and returns ``True`` if :obj:`Module` path and :obj:`Module` itself and returns ``True`` if
it should be quantized and ``False`` otherwise. If ``None``, then it should be quantized and ``False`` otherwise. If ``None``, then
all linear and embedding layers are quantized. Default: ``None``. all layers that define a ``to_quantized(group_size, bits)`` method
are quantized. Default: ``None``.
""" """
class_predicate = class_predicate or ( class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
lambda _, m: isinstance(m, (Linear, Embedding))
)
def _maybe_quantize(path, m): def _maybe_quantize(path, m):
if class_predicate(path, m): if class_predicate(path, m):
if isinstance(m, Linear): if hasattr(m, "to_quantized"):
return QuantizedLinear.from_linear(m, group_size, bits) return m.to_quantized(group_size, bits)
elif isinstance(m, Embedding):
return QuantizedEmbedding.from_embedding(m, group_size, bits)
else: else:
raise ValueError(f"Unable to quantize model of type {type(m)}") raise ValueError(f"Unable to quantize model of type {type(m)}")
else: else:
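Because the default predicate now only checks for a to_quantized method, any custom layer that defines one is picked up automatically. A minimal sketch (NoisyLinear and TinyModel are hypothetical, not part of MLX):

import mlx.nn as nn

class NoisyLinear(nn.Linear):  # hypothetical custom layer
    def to_quantized(self, group_size: int = 64, bits: int = 4):
        # Any module-returning conversion works; here we reuse the stock one.
        return nn.QuantizedLinear.from_linear(self, group_size, bits)

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = NoisyLinear(512, 64)

    def __call__(self, x):
        return self.proj(x)

model = TinyModel()
nn.quantize(model, group_size=64, bits=4)  # proj is replaced in-place
print(type(model.proj).__name__)           # QuantizedLinear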
@ -129,7 +125,7 @@ class QuantizedEmbedding(Module):
@classmethod @classmethod
def from_embedding( def from_embedding(
cls, embedding_layer: Embedding, group_size: int = 64, bits: int = 4 cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
): ):
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer.""" """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
embedding_dims, dims = embedding_layer.weight.shape embedding_dims, dims = embedding_layer.weight.shape
@ -220,7 +216,7 @@ class QuantizedLinear(Module):
return x return x
@classmethod @classmethod
def from_linear(cls, linear_layer: Linear, group_size: int = 64, bits: int = 4): def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer.""" """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits) ql = cls(input_dims, output_dims, False, group_size, bits)

View File

@ -3747,6 +3747,52 @@ void init_ops(nb::module_& m) {
Returns: Returns:
result (array): The dequantized version of ``w`` result (array): The dequantized version of ``w``
)pbdoc"); )pbdoc");
m.def(
"block_sparse_qmm",
&block_sparse_qmm,
nb::arg(),
nb::arg(),
"scales"_a,
"biases"_a,
"lhs_indices"_a = nb::none(),
"rhs_indices"_a = nb::none(),
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def block_sparse_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform quantized matrix multiplication with matrix-level gather.
This operation is the quantized equivalent of :func:`block_sparse_mm`.
Similar to :func:`block_sparse_mm`, the indices ``lhs_indices`` and
``rhs_indices`` contain flat indices along the batch dimensions (i.e.
all but the last two dimensions) of ``x`` and ``w`` respectively.
Note that ``scales`` and ``biases`` must have the same batch dimensions
as ``w`` since they represent the same quantized matrix.
Args:
x (array): Input array
w (array): Quantized matrix packed in unsigned integers
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
lhs_indices (array, optional): Integer indices for ``x`` (default: ``None``)
rhs_indices (array, optional): Integer indices for ``w`` (default: ``None``)
transpose (bool, optional): Defines whether to multiply with the
transposed ``w`` or not, namely whether we are performing
``x @ w.T`` or ``x @ w``. (default: ``True``)
group_size (int, optional): The size of the group in ``w`` that
shares a scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
Returns:
result (array): The result of the multiplication of ``x`` with ``w``
after gathering using ``lhs_indices`` and ``rhs_indices``.
)pbdoc");
m.def( m.def(
"tensordot", "tensordot",
[](const array& a, [](const array& a,
@ -3933,7 +3979,7 @@ void init_ops(nb::module_& m) {
Matrix multiplication with matrix-level gather. Matrix multiplication with matrix-level gather.
Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays. Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays.
This operation is more efficient than explicitly applying a :func:``take`` followed by a :func:``matmul``. This operation is more efficient than explicitly applying a :func:`take` followed by a :func:`matmul`.
The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively. The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively.

View File

@ -277,6 +277,148 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape) self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3) self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_block_sparse_qmm(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
if transpose:
w_hat = w_hat.swapaxes(-1, -2)
return w_hat, qw, s, b
def test_shape(
M,
N,
K,
dtype=mx.float32,
batch_A=(),
batch_B=(),
lhs_indices=None,
rhs_indices=None,
transpose=True,
group_size=64,
bits=4,
):
with self.subTest(
M=M,
N=N,
K=K,
dtype=dtype,
batch_A=batch_A,
batch_B=batch_B,
lhs_indices=lhs_indices,
rhs_indices=rhs_indices,
transpose=transpose,
group_size=group_size,
bits=bits,
):
x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype)
w = mx.random.normal(
shape=batch_B + ((N, K) if transpose else (K, N))
).astype(dtype)
w_hat, qw, s, b = quantize(w, transpose, group_size, bits)
if lhs_indices is not None:
lhs_indices = mx.array(lhs_indices)
if rhs_indices is not None:
rhs_indices = mx.array(rhs_indices)
c1 = mx.block_sparse_mm(x, w_hat, lhs_indices, rhs_indices)
c2 = mx.block_sparse_qmm(
x,
qw,
s,
b,
lhs_indices,
rhs_indices,
transpose=transpose,
group_size=group_size,
bits=bits,
)
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
inputs = (
{
"batch_A": (1,),
"lhs_indices": (0,),
"batch_B": (3,),
"rhs_indices": (2, 1),
},
{
"batch_A": (1,),
"lhs_indices": None,
"batch_B": (3,),
"rhs_indices": (2, 1),
},
{
"batch_A": (2,),
"lhs_indices": None,
"batch_B": (3,),
"rhs_indices": (2, 1),
},
{
"batch_A": (3,),
"lhs_indices": (0, 2),
"batch_B": (1,),
"rhs_indices": (0,),
},
{
"batch_A": (5,),
"lhs_indices": (0, 2),
"batch_B": (3,),
"rhs_indices": (2, 1),
},
{
"batch_A": (4, 2),
"lhs_indices": (
(7, 6),
(5, 4),
(1, 2),
),
"batch_B": (4, 1),
"rhs_indices": ((2,), (0,), (1,)),
},
)
for kwargs in inputs:
test_shape(32, 32, 256, **kwargs)
test_shape(1, 32, 256, **kwargs)
test_shape(32, 256, 32, transpose=False, **kwargs)
test_shape(1, 256, 32, transpose=False, **kwargs)
test_shape(32, 32, 512, **kwargs)
test_shape(1, 32, 512, **kwargs)
test_shape(32, 512, 32, transpose=False, **kwargs)
test_shape(1, 512, 32, transpose=False, **kwargs)
def test_block_sparse_matmul_grad(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
if transpose:
w_hat = w_hat.swapaxes(-1, -2)
return w_hat, qw, s, b
lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
x = mx.random.normal((4, 2, 32, 256))
w = mx.random.normal((4, 1, 32, 256))
w_hat, qw, s, b = quantize(w)
def f_ref(x, w, i1, i2):
return mx.block_sparse_mm(x, w, i1, i2).sum()
def f_test(x, qw, s, b, i1, i2):
return mx.block_sparse_qmm(x, qw, s, b, i1, i2, transpose=True).sum()
r1 = f_ref(x, w_hat, lhs_indices, rhs_indices)
r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices)
self.assertTrue(mx.allclose(r1, r2, atol=1e-4))
g1 = mx.grad(f_ref)(x, w_hat, lhs_indices, rhs_indices)
g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
self.assertTrue(mx.allclose(g1, g2, atol=1e-4))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()