diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index e8f1c18a2..f6d3671b4 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -650,8 +650,8 @@ METAL_FUNC void qvm_impl(
     const device T* biases,
     const device T* x,
     device T* y,
-    const constant int& in_vec_size,
-    const constant int& out_vec_size,
+    const int in_vec_size,
+    const int out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1298,6 +1298,61 @@ template <typename T, const int group_size, const int bits, bool batched>
       simd_lid);
 }
 
+template <typename T, const int group_size, const int bits, int split_k>
+[[kernel]] void qvm_split_k(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& in_vec_size [[buffer(5)]],
+    const constant int& out_vec_size [[buffer(6)]],
+    const constant int& x_batch_ndims [[buffer(7)]],
+    const constant int* x_shape [[buffer(8)]],
+    const constant size_t* x_strides [[buffer(9)]],
+    const constant int& w_batch_ndims [[buffer(10)]],
+    const constant int* w_shape [[buffer(11)]],
+    const constant size_t* w_strides [[buffer(12)]],
+    const constant size_t* s_strides [[buffer(13)]],
+    const constant size_t* b_strides [[buffer(14)]],
+    const constant int& final_block_size [[buffer(15)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  adjust_matrix_offsets<T>(
+      x,
+      w,
+      scales,
+      biases,
+      y,
+      out_vec_size,
+      x_batch_ndims,
+      x_shape,
+      x_strides,
+      w_batch_ndims,
+      w_shape,
+      w_strides,
+      s_strides,
+      b_strides,
+      tid);
+
+  // When (in_vec_size % split_k != 0) the final block needs to be smaller
+  int in_vec_size_adj =
+      tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
+
+  qvm_impl<T, group_size, bits>(
+      w,
+      scales,
+      biases,
+      x,
+      y,
+      in_vec_size_adj,
+      out_vec_size,
+      tid,
+      simd_gid,
+      simd_lid);
+}
+
 template <
     typename T,
     const int group_size,
diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal
index 4720a3cda..5751d953f 100644
--- a/mlx/backend/metal/kernels/quantized.metal
+++ b/mlx/backend/metal/kernels/quantized.metal
@@ -51,6 +51,15 @@
       D, \
       batched)
 
+#define instantiate_quantized_split_k(name, type, group_size, bits, split_k) \
+  instantiate_kernel( \
+      #name "_" #type "_gs_" #group_size "_b_" #bits "_spk_" #split_k, \
+      name, \
+      type, \
+      group_size, \
+      bits, \
+      split_k)
+
 #define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
   instantiate_quantized_batched(name, type, group_size, bits, 1) \
   instantiate_quantized_batched(name, type, group_size, bits, 0)
@@ -84,11 +93,16 @@
   instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
   instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
 
+#define instantiate_quantized_all_splitk(type, group_size, bits) \
+  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
+  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
+
 #define instantiate_quantized_funcs(type, group_size, bits) \
   instantiate_quantized_all_single(type, group_size, bits) \
   instantiate_quantized_all_batched(type, group_size, bits) \
   instantiate_quantized_all_aligned(type, group_size, bits) \
-  instantiate_quantized_all_quad(type, group_size, bits)
+  instantiate_quantized_all_quad(type, group_size, bits) \
+  instantiate_quantized_all_splitk(type, group_size, bits)
 
 #define instantiate_quantized_types(group_size, bits) \
   instantiate_quantized_funcs(float, group_size, bits) \
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index 30828da70..4a74f2925 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -6,6 +6,7 @@
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
+#include "mlx/backend/metal/reduce.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"
@@ -148,6 +149,125 @@ void launch_qmm(
   d.add_temporaries(std::move(copies), s.index);
 }
 
+void qvm_split_k(
+    const std::vector<array>& inputs,
+    array& out,
+    int group_size,
+    int bits,
+    int D,
+    int O,
+    int B,
+    int N,
+    const Stream& s) {
+  int split_k = D > 8192 ? 32 : 8;
+  int split_D = (D + split_k - 1) / split_k;
+  N *= split_k;
+
+  int bo = 64;
+  int bd = 32;
+  MTL::Size group_dims = MTL::Size(bd, 2, 1);
+  MTL::Size grid_dims = MTL::Size(O / bo, B, N);
+
+  auto& x_pre = inputs[0];
+  auto& w_pre = inputs[1];
+  auto& scales_pre = inputs[2];
+  auto& biases_pre = inputs[3];
+
+  // Ensure that the last two dims are row contiguous.
+  // TODO: Check if we really need this for x as well...
+  std::vector<array> copies;
+  auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) {
+    auto stride_0 = arr.strides()[arr.ndim() - 2];
+    auto stride_1 = arr.strides()[arr.ndim() - 1];
+    if (stride_0 == arr.shape(-1) && stride_1 == 1) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy_gpu(arr, arr_copy, CopyType::General, s);
+      copies.push_back(arr_copy);
+      return arr_copy;
+    }
+  };
+  auto x = ensure_row_contiguous_last_dims(x_pre);
+  auto w = ensure_row_contiguous_last_dims(w_pre);
+  auto scales = ensure_row_contiguous_last_dims(scales_pre);
+  auto biases = ensure_row_contiguous_last_dims(biases_pre);
+
+  int x_batch_ndims = x.ndim() - 2;
+  auto x_shape = x.shape();
+  auto x_strides = x.strides();
+  int w_batch_ndims = w.ndim() - 2;
+  auto w_shape = w.shape();
+  auto w_strides = w.strides();
+  auto s_strides = scales.strides();
+  auto b_strides = biases.strides();
+
+  // Add split_k dim with reshapes
+  x_shape.insert(x_shape.end() - 2, split_k);
+  x_shape.back() /= split_k;
+  x_strides.insert(x_strides.end() - 2, split_D);
+  x_strides[x.ndim() - 1] = split_D;
+  x_batch_ndims += 1;
+
+  w_shape.insert(w_shape.end() - 2, split_k);
+  w_shape[w.ndim() - 1] /= split_k;
+  w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1));
+  w_batch_ndims += 1;
+  s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1));
+  b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1));
+
+  int final_block_size = D - (split_k - 1) * split_D;
+
+  auto& d = metal::device(s.device);
+
+  auto temp_shape = out.shape();
+  temp_shape.insert(temp_shape.end() - 2, split_k);
+  array intermediate(temp_shape, x.dtype(), nullptr, {});
+  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  d.add_temporary(intermediate, s.index);
+
+  std::ostringstream kname;
+  auto type_string = get_type_string(x.dtype());
+  kname << "qvm_split_k" << "_" << type_string << "_gs_" << group_size << "_b_"
+        << bits << "_spk_" << split_k;
+  auto template_def = get_template_definition(
+      kname.str(), "qvm_split_k", type_string, group_size, bits, split_k);
+
+  // Encode and dispatch kernel
+  auto kernel = get_quantized_kernel(d, kname.str(), template_def);
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder->setComputePipelineState(kernel);
+
+  compute_encoder.set_input_array(w, 0);
+  compute_encoder.set_input_array(scales, 1);
+  compute_encoder.set_input_array(biases, 2);
+  compute_encoder.set_input_array(x, 3);
+  compute_encoder.set_output_array(intermediate, 4);
+  compute_encoder->setBytes(&split_D, sizeof(int), 5);
+  compute_encoder->setBytes(&O, sizeof(int), 6);
+
+  compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 7);
+  set_vector_bytes(compute_encoder, x_shape, 8);
+  set_vector_bytes(compute_encoder, x_strides, 9);
+  compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 10);
+  set_vector_bytes(compute_encoder, w_shape, 11);
+  set_vector_bytes(compute_encoder, w_strides, 12);
+  set_vector_bytes(compute_encoder, s_strides, 13);
+  set_vector_bytes(compute_encoder, b_strides, 14);
+  compute_encoder->setBytes(&final_block_size, sizeof(int), 15);
+
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  d.add_temporaries(std::move(copies), s.index);
+
+  int axis = intermediate.ndim() - 3;
+  ReductionPlan plan(
+      ReductionOpType::ContiguousStridedReduce,
+      {intermediate.shape(axis)},
+      {intermediate.strides(axis)});
+  strided_reduce_general_dispatch(
+      intermediate, out, "sum", plan, {axis}, compute_encoder, d, s);
+}
+
 void qmm_op(
     const std::vector<array>& inputs,
     array& out,
@@ -211,7 +331,9 @@ void qmm_op(
       aligned = true;
     }
   } else {
-    if (B < 4) {
+    if (B < 4 && D >= 1024 && !gather) {
+      return qvm_split_k(inputs, out, group_size, bits, D, O, B, N, s);
+    } else if (B < 4) {
       name += "qvm";
       int bo = 64;
       int bd = 32;
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 2b5251847..607f7ef24 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -163,6 +163,31 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertEqual(y_q.shape, y_hat.shape)
         self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
+    def test_qvm_splitk(self):
+        key = mx.random.key(0)
+        k1, k2 = mx.random.split(key)
+        tests = product(
+            [128, 64, 32],  # group_size
+            [2, 4, 8],  # bits
+            [128],  # M
+            [16384],  # N
+            [1, 3],  # B
+        )
+        for group_size, bits, M, N, B in tests:
+            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
+                x_shape = (1, N) if B == 0 else (B, 1, N)
+                w_shape = (N, M) if B == 0 else (B, N, M)
+                x = mx.random.normal(shape=x_shape, key=k1)
+                w = mx.random.normal(shape=w_shape, key=k2)
+                w_q, scales, biases = mx.quantize(w, group_size, bits)
+                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+                y_q = mx.quantized_matmul(
+                    x, w_q, scales, biases, False, group_size, bits
+                )
+                y_hat = x @ w_hat
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 2e-3)
+
     def test_throw(self):
         x = mx.random.normal(shape=(10, 512))
         w = mx.random.normal(shape=(32, 512))
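For reference, a minimal sketch of how to hit the new split-k path from Python. The shapes below are illustrative, chosen only to satisfy the dispatch condition added to qmm_op above (flattened batch B < 4, inner dimension D >= 1024, not a gather); it uses the same public API exercised by the new test.

    import mlx.core as mx

    # B = 1 row, D = 16384 inner dim -> routed to the qvm_split_k kernel.
    D, O = 16384, 128
    x = mx.random.normal(shape=(1, D))
    w = mx.random.normal(shape=(D, O))
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
    y = mx.quantized_matmul(
        x, w_q, scales, biases, transpose=False, group_size=64, bits=4
    )
    mx.eval(y)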