Add split_k qvm for long context (#1564)

* Add splitk qvm

* configurable splitk

* tuning

* remove extra instantiation

* remove refactor

* separate test

* cpu tolerance
Author: Alex Barron, 2024-11-05 11:25:19 -08:00 (committed by GitHub)
parent 248431eb3c
commit 26be608470
4 changed files with 220 additions and 4 deletions
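The idea behind the new kernel: when x has only a few rows but the reduction dimension D is very long (the "long context" case), the usual split-k trick applies: split the reduction into split_k partitions that run in parallel, then sum the partial results. Below is a minimal Python sketch of that decomposition, written against the public mlx API for illustration only; the sizes and tolerances are assumptions, not values taken from the diff.

# Illustrative sketch (not part of the commit): split-k decomposition of a
# long vector-matrix product y = x @ w, which is what qvm_split_k computes on the GPU.
import mlx.core as mx

D, O, split_k = 16384, 128, 32            # assumed reduction length, output size, partitions
split_D = (D + split_k - 1) // split_k    # per-partition reduction length

x = mx.random.normal(shape=(1, D))
w = mx.random.normal(shape=(D, O))

# One partial product per partition; the kernel indexes partitions through the
# z dimension of the launch grid and writes each result into an intermediate array.
partials = mx.stack(
    [
        x[:, k * split_D : (k + 1) * split_D] @ w[k * split_D : (k + 1) * split_D]
        for k in range(split_k)
    ]
)

# Summing over the split_k axis recovers the full product; in the commit this is
# the strided sum-reduce dispatched after the kernel.
y = partials.sum(axis=0)
assert mx.allclose(y, x @ w, rtol=1e-3, atol=1e-2).item()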


@@ -650,8 +650,8 @@ METAL_FUNC void qvm_impl(
const device T* biases,
const device T* x,
device T* y,
-const constant int& in_vec_size,
-const constant int& out_vec_size,
+const int in_vec_size,
+const int out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1298,6 +1298,61 @@ template <typename T, const int group_size, const int bits, bool batched>
simd_lid);
}
template <typename T, const int group_size, const int bits, int split_k = 32>
[[kernel]] void qvm_split_k(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int& final_block_size [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
// When (in_vec_size % split_k != 0) the final block needs to be smaller
int in_vec_size_adj =
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
qvm_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size_adj,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
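Each slice along the grid's z dimension reduces over split_D = ceil(D / split_k) elements; when D is not divisible by split_k, the last partition reduces over final_block_size elements instead, which is what the tid.z % split_k check above selects. A small arithmetic sketch, with a hypothetical D chosen purely for illustration:

# Illustrative only: partition sizes when the reduction dim is not divisible by split_k.
D, split_k = 10000, 32
split_D = (D + split_k - 1) // split_k          # 313, the host-side split_D / in_vec_size
final_block_size = D - (split_k - 1) * split_D  # 297, passed to the kernel in buffer(15)
assert (split_k - 1) * split_D + final_block_size == D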
template <
typename T,
const int group_size,


@@ -51,6 +51,15 @@
D, \
batched)
#define instantiate_quantized_split_k(name, type, group_size, bits, split_k) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_spk_" #split_k, \
name, \
type, \
group_size, \
bits, \
split_k)
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0)
@@ -84,11 +93,16 @@
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
#define instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
-instantiate_quantized_all_quad(type, group_size, bits)
+instantiate_quantized_all_quad(type, group_size, bits) \
+instantiate_quantized_all_splitk(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \
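For reference, instantiate_quantized_split_k stamps out one kernel per (type, group_size, bits, split_k) combination, and the host code in the next file rebuilds the same name string to look the pipeline up. A hypothetical helper, for illustration only, that mirrors that naming scheme:

# Not part of the commit: mirrors the name built by instantiate_quantized_split_k
# and by the host-side kname stream.
def qvm_split_k_kernel_name(type_string, group_size, bits, split_k):
    return f"qvm_split_k_{type_string}_gs_{group_size}_b_{bits}_spk_{split_k}"

print(qvm_split_k_kernel_name("float", 64, 4, 32))
# qvm_split_k_float_gs_64_b_4_spk_32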


@ -6,6 +6,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
@@ -148,6 +149,125 @@ void launch_qmm(
d.add_temporaries(std::move(copies), s.index);
}
void qvm_split_k(
const std::vector<array>& inputs,
array& out,
int group_size,
int bits,
int D,
int O,
int B,
int N,
const Stream& s) {
int split_k = D > 8192 ? 32 : 8;
int split_D = (D + split_k - 1) / split_k;
N *= split_k;
int bo = 64;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, N);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
// Ensure that the last two dims are row contiguous.
// TODO: Check if we really need this for x as well...
std::vector<array> copies;
auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
int x_batch_ndims = x.ndim() - 2;
auto x_shape = x.shape();
auto x_strides = x.strides();
int w_batch_ndims = w.ndim() - 2;
auto w_shape = w.shape();
auto w_strides = w.strides();
auto s_strides = scales.strides();
auto b_strides = biases.strides();
// Add split_k dim with reshapes
x_shape.insert(x_shape.end() - 2, split_k);
x_shape.back() /= split_k;
x_strides.insert(x_strides.end() - 2, split_D);
x_strides[x.ndim() - 1] = split_D;
x_batch_ndims += 1;
w_shape.insert(w_shape.end() - 2, split_k);
w_shape[w.ndim() - 1] /= split_k;
w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1));
w_batch_ndims += 1;
s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1));
b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1));
int final_block_size = D - (split_k - 1) * split_D;
auto& d = metal::device(s.device);
auto temp_shape = out.shape();
temp_shape.insert(temp_shape.end() - 2, split_k);
array intermediate(temp_shape, x.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qvm_split_k" << "_" << type_string << "_gs_" << group_size << "_b_"
<< bits << "_spk_" << split_k;
auto template_def = get_template_definition(
kname.str(), "qvm_split_k", type_string, group_size, bits, split_k);
// Encode and dispatch kernel
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(intermediate, 4);
compute_encoder->setBytes(&split_D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 7);
set_vector_bytes(compute_encoder, x_shape, 8);
set_vector_bytes(compute_encoder, x_strides, 9);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 10);
set_vector_bytes(compute_encoder, w_shape, 11);
set_vector_bytes(compute_encoder, w_strides, 12);
set_vector_bytes(compute_encoder, s_strides, 13);
set_vector_bytes(compute_encoder, b_strides, 14);
compute_encoder->setBytes(&final_block_size, sizeof(int), 15);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
int axis = intermediate.ndim() - 3;
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce,
{intermediate.shape(axis)},
{intermediate.strides(axis)});
strided_reduce_general_dispatch(
intermediate, out, "sum", plan, {axis}, compute_encoder, d, s);
}
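The epilogue above works on a temporary whose shape is the output shape with a split_k axis inserted before the last two dims; the strided sum-reduce then collapses that axis into out. A hedged sketch of the shape bookkeeping, with illustrative batch and output sizes:

# Illustrative only: shape bookkeeping for the intermediate partial-sum array.
B, O, split_k = 3, 128, 32
out_shape = [B, 1, O]                            # assumed qvm output: one row per batch element
temp_shape = list(out_shape)
temp_shape.insert(len(temp_shape) - 2, split_k)  # -> [B, split_k, 1, O]
axis = len(temp_shape) - 3                       # the inserted split_k axis
assert temp_shape[axis] == split_k               # this is the axis the sum-reduce collapses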
void qmm_op(
const std::vector<array>& inputs,
array& out,
@@ -211,7 +331,9 @@ void qmm_op(
aligned = true;
}
} else {
-if (B < 4) {
+if (B < 4 && D >= 1024 && !gather) {
+return qvm_split_k(inputs, out, group_size, bits, D, O, B, N, s);
+} else if (B < 4) {
name += "qvm";
int bo = 64;
int bd = 32;


@@ -163,6 +163,31 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_qvm_splitk(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
tests = product(
[128, 64, 32], # group_size
[2, 4, 8], # bits
[128], # M
[16384], # N
[1, 3], # B
)
for group_size, bits, M, N, B in tests:
with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
x_shape = (1, N) if B == 0 else (B, 1, N)
w_shape = (N, M) if B == 0 else (B, N, M)
x = mx.random.normal(shape=x_shape, key=k1)
w = mx.random.normal(shape=w_shape, key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
y_q = mx.quantized_matmul(
x, w_q, scales, biases, False, group_size, bits
)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
def test_throw(self):
x = mx.random.normal(shape=(10, 512))
w = mx.random.normal(shape=(32, 512))