mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 03:22:54 +08:00
[Feature]Add no parallel-m qmm kernel to improve decoding performance
This commit is contained in:
parent
b529515eb1
commit
b2f0ebe9ee
@ -688,6 +688,82 @@ METAL_FUNC void qmv_fast_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
METAL_FUNC void qmv_no_parallel_m_impl(
|
||||
const device uint32_t* w,
|
||||
const device T* scales,
|
||||
const device T* biases,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& m_size,
|
||||
const constant int& k_size,
|
||||
const constant int& n_size,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
constexpr int packs_per_thread = bits == 2 ? 1 : 2;
|
||||
constexpr int num_simdgroups = 2;
|
||||
constexpr int results_per_simdgroup = 4;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
|
||||
constexpr int values_per_thread = pack_factor * packs_per_thread;
|
||||
constexpr int block_size = values_per_thread * SIMD_SIZE;
|
||||
constexpr int scale_step_per_thread = group_size / values_per_thread;
|
||||
constexpr int max_batch = 10;
|
||||
|
||||
const device uint8_t* ws = (const device uint8_t*)w;
|
||||
|
||||
typedef float U;
|
||||
|
||||
thread U x_thread[values_per_thread];
|
||||
thread U result[max_batch * results_per_simdgroup] = {0};
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = k_size * bytes_per_pack / pack_factor;
|
||||
const int in_vec_size_g = k_size / group_size;
|
||||
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
|
||||
simd_gid * results_per_simdgroup;
|
||||
|
||||
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
|
||||
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
||||
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
||||
// x += tid.x * in_vec_size + simd_lid * values_per_thread;
|
||||
// y += tid.x * out_vec_size + out_row;
|
||||
|
||||
for (int k = 0; k < k_size; k += block_size) {
|
||||
// U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
|
||||
for (int col = 0; col < m_size; col++) {
|
||||
auto x_temp = x + col * k_size + simd_lid * values_per_thread + k;
|
||||
U sum = load_vector<T, U, values_per_thread, bits>(x_temp, x_thread);
|
||||
result[col * results_per_simdgroup + row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
||||
}
|
||||
}
|
||||
|
||||
ws += block_size * bytes_per_pack / pack_factor;
|
||||
scales += block_size / group_size;
|
||||
biases += block_size / group_size;
|
||||
// x += block_size;
|
||||
}
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
for (int col = 0; col < m_size; col++) {
|
||||
result[col * results_per_simdgroup + row] = simd_sum(result[col * results_per_simdgroup + row]);
|
||||
auto y_temp = y + col * n_size + out_row;
|
||||
if (simd_lid == 0) {
|
||||
y_temp[row] = static_cast<T>(result[col * results_per_simdgroup + row]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits>
|
||||
METAL_FUNC void qmv_impl(
|
||||
const device uint32_t* w,
|
||||
@ -1410,6 +1486,59 @@ template <typename T, int group_size, int bits, bool batched>
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits, bool batched>
|
||||
[[kernel]] void qmv_no_parallel_m(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
const device T* biases [[buffer(2)]],
|
||||
const device T* x [[buffer(3)]],
|
||||
device T* y [[buffer(4)]],
|
||||
const constant int& m_size [[buffer(5)]],
|
||||
const constant int& k_size [[buffer(6)]],
|
||||
const constant int& n_size [[buffer(7)]],
|
||||
const constant int& x_batch_ndims [[buffer(8)]],
|
||||
const constant int* x_shape [[buffer(9)]],
|
||||
const constant int64_t* x_strides [[buffer(10)]],
|
||||
const constant int& w_batch_ndims [[buffer(11)]],
|
||||
const constant int* w_shape [[buffer(12)]],
|
||||
const constant int64_t* w_strides [[buffer(13)]],
|
||||
const constant int64_t* s_strides [[buffer(14)]],
|
||||
const constant int64_t* b_strides [[buffer(15)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
if (batched) {
|
||||
adjust_matrix_offsets<T>(
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
y,
|
||||
n_size * m_size,
|
||||
x_batch_ndims,
|
||||
x_shape,
|
||||
x_strides,
|
||||
w_batch_ndims,
|
||||
w_shape,
|
||||
w_strides,
|
||||
s_strides,
|
||||
b_strides,
|
||||
tid);
|
||||
}
|
||||
qmv_no_parallel_m_impl<T, group_size, bits>(
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
x,
|
||||
y,
|
||||
m_size,
|
||||
k_size,
|
||||
n_size,
|
||||
tid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, const int bits, bool batched>
|
||||
[[kernel]] void qmv(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
|
@ -80,6 +80,7 @@
|
||||
|
||||
#define instantiate_quantized_all_batched(type, group_size, bits) \
|
||||
instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \
|
||||
instantiate_quantized_batched_wrap(qmv_no_parallel_m, type, group_size, bits) \
|
||||
instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \
|
||||
instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \
|
||||
instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)
|
||||
|
@ -244,6 +244,59 @@ void qmv(
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void qmv_no_parallel_m(
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
array& out,
|
||||
int group_size,
|
||||
int bits,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
int B = out.size() / M / N;
|
||||
|
||||
int bn = 128;
|
||||
// int bk = 32;
|
||||
MTL::Size group_dims(2, 1, 1);
|
||||
MTL::Size grid_dims((N + bn - 1) / bn, 1, B);
|
||||
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
std::string type_string = get_type_string(x.dtype());
|
||||
// bool fast = N % bn == 0 && K % 512 == 0;
|
||||
concatenate(
|
||||
kname,
|
||||
"qmv_no_parallel_m_",
|
||||
type_string,
|
||||
"_gs_",
|
||||
group_size,
|
||||
"_b_",
|
||||
bits,
|
||||
B > 1 ? "_batch_1" : "_batch_0");
|
||||
auto template_def = get_template_definition(
|
||||
kname, "qmv_no_parallel_m", type_string, group_size, bits, B > 1);
|
||||
|
||||
auto kernel = get_quantized_kernel(d, kname, template_def);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(w, 0);
|
||||
compute_encoder.set_input_array(scales, 1);
|
||||
compute_encoder.set_input_array(biases, 2);
|
||||
compute_encoder.set_input_array(x, 3);
|
||||
compute_encoder.set_output_array(out, 4);
|
||||
compute_encoder.set_bytes(M, 5);
|
||||
compute_encoder.set_bytes(K, 6);
|
||||
compute_encoder.set_bytes(N, 7);
|
||||
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 8);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void qvm_split_k(
|
||||
const array& x,
|
||||
const array& w,
|
||||
@ -818,7 +871,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Run of the mill qmv
|
||||
if (transpose_) {
|
||||
qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
|
||||
qmv_no_parallel_m(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
|
||||
return;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user