Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
Add split_k qvm for long context (#1564)
* Add splitk qvm
* configurable splitk
* tuning
* remove extra instantiation
* remove refactor
* separate test
* cpu tolerance
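In plain terms: the existing qvm kernel parallelizes over the output dimension and the batch, while each threadgroup loops sequentially over the whole inner dimension D. With a small batch (single-token decode) and a very long context, that leaves few threadgroups doing long serial reductions. The change below splits D into split_k partitions, runs each partition as an extra batch entry writing a partial product into a temporary, then sums the partials with a strided reduce. A minimal CPU sketch of that decomposition — plain floats stand in for the packed quantized weights, scales, and biases, and the sizes mirror the heuristic in the diff below:

#include <algorithm>
#include <vector>

// Sketch: y[o] = sum_d x[d] * W[d][o], computed as split_k partial sums
// followed by a reduction, mirroring the two GPU passes in the diff.
std::vector<float> qvm_split_k_sketch(
    const std::vector<float>& x, // length D
    const std::vector<std::vector<float>>& W, // D rows of length O
    int D,
    int O) {
  int split_k = D > 8192 ? 32 : 8; // same split heuristic as the Metal path
  int split_D = (D + split_k - 1) / split_k; // ceil-divide; last slice may be short

  // Pass 1: one partial product per partition (the qvm_split_k kernel).
  std::vector<std::vector<float>> partials(split_k, std::vector<float>(O, 0.f));
  for (int k = 0; k < split_k; ++k) {
    int begin = k * split_D;
    int end = std::min(D, begin + split_D);
    for (int d = begin; d < end; ++d) {
      for (int o = 0; o < O; ++o) {
        partials[k][o] += x[d] * W[d][o];
      }
    }
  }

  // Pass 2: sum over the split_k axis (a strided sum reduce on the GPU).
  std::vector<float> y(O, 0.f);
  for (int k = 0; k < split_k; ++k) {
    for (int o = 0; o < O; ++o) {
      y[o] += partials[k][o];
    }
  }
  return y;
}
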
@@ -6,6 +6,7 @@
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
+#include "mlx/backend/metal/reduce.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"

@@ -148,6 +149,125 @@ void launch_qmm(
  d.add_temporaries(std::move(copies), s.index);
}

void qvm_split_k(
    const std::vector<array>& inputs,
    array& out,
    int group_size,
    int bits,
    int D,
    int O,
    int B,
    int N,
    const Stream& s) {
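  // Two-pass strategy: slice the inner dimension D into split_k partitions,
  // compute one partial product per partition into a temporary, and sum the
  // partials with a reduce below. Very long D (> 8192) gets a finer split to
  // expose more parallelism; scaling N lets each partition ride along the
  // batch (z) dimension of the grid.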
  int split_k = D > 8192 ? 32 : 8;
  int split_D = (D + split_k - 1) / split_k;
  N *= split_k;

  int bo = 64;
  int bd = 32;
  MTL::Size group_dims = MTL::Size(bd, 2, 1);
  MTL::Size grid_dims = MTL::Size(O / bo, B, N);

  auto& x_pre = inputs[0];
  auto& w_pre = inputs[1];
  auto& scales_pre = inputs[2];
  auto& biases_pre = inputs[3];

  // Ensure that the last two dims are row contiguous.
  // TODO: Check if we really need this for x as well...
  std::vector<array> copies;
  auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) {
    auto stride_0 = arr.strides()[arr.ndim() - 2];
    auto stride_1 = arr.strides()[arr.ndim() - 1];
    if (stride_0 == arr.shape(-1) && stride_1 == 1) {
      return arr;
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy_gpu(arr, arr_copy, CopyType::General, s);
      copies.push_back(arr_copy);
      return arr_copy;
    }
  };
  auto x = ensure_row_contiguous_last_dims(x_pre);
  auto w = ensure_row_contiguous_last_dims(w_pre);
  auto scales = ensure_row_contiguous_last_dims(scales_pre);
  auto biases = ensure_row_contiguous_last_dims(biases_pre);

  int x_batch_ndims = x.ndim() - 2;
  auto x_shape = x.shape();
  auto x_strides = x.strides();
  int w_batch_ndims = w.ndim() - 2;
  auto w_shape = w.shape();
  auto w_strides = w.strides();
  auto s_strides = scales.strides();
  auto b_strides = biases.strides();
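
  // The partitioning is pure stride arithmetic: inserting a split_k axis
  // into the shapes/strides below lets the kernel see each slice of D as
  // one more batch dimension, so the split itself moves no data.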
  // Add split_k dim with reshapes
  x_shape.insert(x_shape.end() - 2, split_k);
  x_shape.back() /= split_k;
  x_strides.insert(x_strides.end() - 2, split_D);
  x_strides[x.ndim() - 1] = split_D;
  x_batch_ndims += 1;

  w_shape.insert(w_shape.end() - 2, split_k);
  w_shape[w.ndim() - 1] /= split_k;
  w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1));
  w_batch_ndims += 1;
  s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1));
  b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1));
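
  // split_D rounds up, so the last partition can be shorter than the rest;
  // its true length is passed to the kernel as final_block_size.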
  int final_block_size = D - (split_k - 1) * split_D;

  auto& d = metal::device(s.device);
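
  // The intermediate holds one partial result per partition: out.shape()
  // with a split_k axis inserted. Registering it as a temporary keeps it
  // alive only for the duration of this command buffer.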
  auto temp_shape = out.shape();
  temp_shape.insert(temp_shape.end() - 2, split_k);
  array intermediate(temp_shape, x.dtype(), nullptr, {});
  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
  d.add_temporary(intermediate, s.index);

  std::ostringstream kname;
  auto type_string = get_type_string(x.dtype());
  kname << "qvm_split_k" << "_" << type_string << "_gs_" << group_size << "_b_"
        << bits << "_spk_" << split_k;
  auto template_def = get_template_definition(
      kname.str(), "qvm_split_k", type_string, group_size, bits, split_k);

  // Encode and dispatch kernel
  auto kernel = get_quantized_kernel(d, kname.str(), template_def);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);

  compute_encoder.set_input_array(w, 0);
  compute_encoder.set_input_array(scales, 1);
  compute_encoder.set_input_array(biases, 2);
  compute_encoder.set_input_array(x, 3);
  compute_encoder.set_output_array(intermediate, 4);
  compute_encoder->setBytes(&split_D, sizeof(int), 5);
  compute_encoder->setBytes(&O, sizeof(int), 6);

  compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 7);
  set_vector_bytes(compute_encoder, x_shape, 8);
  set_vector_bytes(compute_encoder, x_strides, 9);
  compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 10);
  set_vector_bytes(compute_encoder, w_shape, 11);
  set_vector_bytes(compute_encoder, w_strides, 12);
  set_vector_bytes(compute_encoder, s_strides, 13);
  set_vector_bytes(compute_encoder, b_strides, 14);
  compute_encoder->setBytes(&final_block_size, sizeof(int), 15);

  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
  d.add_temporaries(std::move(copies), s.index);
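
  // Second pass: sum the split_k partials into `out` with the generic
  // strided-reduce kernel (the reason reduce.h is now included above).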
  int axis = intermediate.ndim() - 3;
  ReductionPlan plan(
      ReductionOpType::ContiguousStridedReduce,
      {intermediate.shape(axis)},
      {intermediate.strides(axis)});
  strided_reduce_general_dispatch(
      intermediate, out, "sum", plan, {axis}, compute_encoder, d, s);
}

void qmm_op(
    const std::vector<array>& inputs,
    array& out,

@@ -211,7 +331,9 @@ void qmm_op(
       aligned = true;
     }
   } else {
-    if (B < 4) {
+    if (B < 4 && D >= 1024 && !gather) {
+      return qvm_split_k(inputs, out, group_size, bits, D, O, B, N, s);
+    } else if (B < 4) {
       name += "qvm";
       int bo = 64;
       int bd = 32;
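
The new routing in qmm_op sends only vector-like, long-inner-dimension, non-gather products down the split-k path; everything else keeps its existing kernel. A hypothetical helper, not part of the diff, that mirrors the condition:

// Hypothetical helper mirroring the dispatch above: true when the
// split-k kernel would be selected.
bool takes_split_k_path(int B, int D, bool gather) {
  return B < 4 && D >= 1024 && !gather;
}
// takes_split_k_path(1, 32768, false) -> true   (single-token, long-context decode)
// takes_split_k_path(1, 512, false)   -> false  (short D: plain qvm)
// takes_split_k_path(8, 32768, false) -> false  (larger batch: other quantized kernels)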