mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add split_k qvm for long context (#1564)
* Add splitk qvm * configurable splitk * tuning * remove extra instantiation * remove refactor * separate test * cpu tolerance
This commit is contained in:
@@ -650,8 +650,8 @@ METAL_FUNC void qvm_impl(
|
||||
const device T* biases,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
const constant int& out_vec_size,
|
||||
const int in_vec_size,
|
||||
const int out_vec_size,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
@@ -1298,6 +1298,61 @@ template <typename T, const int group_size, const int bits, bool batched>
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, const int bits, int split_k = 32>
|
||||
[[kernel]] void qvm_split_k(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
const device T* x [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& in_vec_size [[buffer(5)]],
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
const constant int& x_batch_ndims [[buffer(7)]],
|
||||
const constant int* x_shape [[buffer(8)]],
|
||||
const constant size_t* x_strides [[buffer(9)]],
|
||||
const constant int& w_batch_ndims [[buffer(10)]],
|
||||
const constant int* w_shape [[buffer(11)]],
|
||||
const constant size_t* w_strides [[buffer(12)]],
|
||||
const constant size_t* s_strides [[buffer(13)]],
|
||||
const constant size_t* b_strides [[buffer(14)]],
|
||||
const constant int& final_block_size [[buffer(15)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
adjust_matrix_offsets<T>(
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
y,
|
||||
out_vec_size,
|
||||
x_batch_ndims,
|
||||
x_shape,
|
||||
x_strides,
|
||||
w_batch_ndims,
|
||||
w_shape,
|
||||
w_strides,
|
||||
s_strides,
|
||||
b_strides,
|
||||
tid);
|
||||
|
||||
// When (in_vec_size % split_k != 0) the final block needs to be smaller
|
||||
int in_vec_size_adj =
|
||||
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
|
||||
|
||||
qvm_impl<T, group_size, bits>(
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
x,
|
||||
y,
|
||||
in_vec_size_adj,
|
||||
out_vec_size,
|
||||
tid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
|
||||
@@ -51,6 +51,15 @@
|
||||
D, \
|
||||
batched)
|
||||
|
||||
#define instantiate_quantized_split_k(name, type, group_size, bits, split_k) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_" #group_size "_b_" #bits "_spk_" #split_k, \
|
||||
name, \
|
||||
type, \
|
||||
group_size, \
|
||||
bits, \
|
||||
split_k)
|
||||
|
||||
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
|
||||
instantiate_quantized_batched(name, type, group_size, bits, 1) \
|
||||
instantiate_quantized_batched(name, type, group_size, bits, 0)
|
||||
@@ -84,11 +93,16 @@
|
||||
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
|
||||
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
|
||||
|
||||
#define instantiate_quantized_all_splitk(type, group_size, bits) \
|
||||
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
|
||||
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
|
||||
|
||||
#define instantiate_quantized_funcs(type, group_size, bits) \
|
||||
instantiate_quantized_all_single(type, group_size, bits) \
|
||||
instantiate_quantized_all_batched(type, group_size, bits) \
|
||||
instantiate_quantized_all_aligned(type, group_size, bits) \
|
||||
instantiate_quantized_all_quad(type, group_size, bits)
|
||||
instantiate_quantized_all_quad(type, group_size, bits) \
|
||||
instantiate_quantized_all_splitk(type, group_size, bits)
|
||||
|
||||
#define instantiate_quantized_types(group_size, bits) \
|
||||
instantiate_quantized_funcs(float, group_size, bits) \
|
||||
|
||||
Reference in New Issue
Block a user