From 5824626c0befcca45e59b9b6ebda89e5c5bf1e2e Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Tue, 22 Oct 2024 16:14:29 -0700 Subject: [PATCH] start --- benchmarks/python/sdpa_vector_bench.py | 65 ++++-- .../scaled_dot_product_attention.metal | 28 +-- mlx/backend/metal/kernels/sdpa_vector.h | 192 ++++++++++++++++-- .../metal/scaled_dot_product_attention.cpp | 119 +++++++++-- mlx/fast.cpp | 123 ++++++++++- mlx/fast.h | 15 ++ mlx/fast_primitives.h | 9 +- python/src/fast.cpp | 43 ++++ 8 files changed, 532 insertions(+), 62 deletions(-) diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py index 058f495dc..c05fb8f39 100644 --- a/benchmarks/python/sdpa_vector_bench.py +++ b/benchmarks/python/sdpa_vector_bench.py @@ -1,16 +1,18 @@ -import argparse -import math - import mlx.core as mx +import numpy as np from time_utils import time_fn -L = 1024 +L = 30000 H = 32 H_k = 32 // 4 D = 128 def attention(q, k, v): + k = mx.quantize(k) + v = mx.quantize(v) + k = mx.dequantize(*k) + v = mx.dequantize(*v) B, Hq, L, D = q.shape _, Hk, S, _ = k.shape q = q.reshape(B, Hk, Hq // Hk, L, D) @@ -23,27 +25,54 @@ def attention(q, k, v): def sdpa(q, k, v): + k = mx.quantize(k) + v = mx.quantize(v) + k = mx.dequantize(*k) + v = mx.dequantize(*v) return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) -def time_self_attention_primitives(): - mx.random.seed(3) - q = mx.random.uniform(shape=(1, H, 1, D)) - k = mx.random.uniform(shape=(1, H_k, L, D)) - v = mx.random.uniform(shape=(1, H_k, L, D)) - mx.eval(q, k, v) +def quant_sdpa(q, k, v): + k = mx.quantize(k) + v = mx.quantize(v) + return mx.fast.quantized_scaled_dot_product_attention(q, *k, *v, scale=1.0) + + +def time_self_attention_primitives(q, k, v): time_fn(attention, q, k, v) -def time_self_attention_sdpa(): - mx.random.seed(3) - q = mx.random.uniform(shape=(1, H, 1, D)) - k = mx.random.uniform(shape=(1, H_k, L, D)) - v = mx.random.uniform(shape=(1, H_k, L, D)) - mx.eval(q, k, v) +def time_self_attention_sdpa(q, k, v): time_fn(sdpa, q, k, v) +def time_self_attention_quant_sdpa(q, k, v): + time_fn(quant_sdpa, q, k, v) + + if __name__ == "__main__": - time_self_attention_sdpa() - time_self_attention_primitives() + mx.random.seed(3) + q = mx.random.uniform(shape=(1, H, 10, D)) + k = mx.random.uniform(shape=(1, H_k, L, D)) + v = mx.random.uniform(shape=(1, H_k, L, D)) + mx.eval(q, k, v) + + k_quant = mx.quantize(k) + v_quant = mx.quantize(v) + mx.eval(k_quant, v_quant) + + # time_self_attention_sdpa(q, k, v) + # time_self_attention_quant_sdpa(q, k_quant, v_quant) + # time_self_attention_primitives(q, k, v) + q_sdpa = quant_sdpa(q, k, v) + print(q_sdpa) + o_attention = attention(q, k, v) + print(o_attention) + np.testing.assert_allclose(q_sdpa, o_attention, atol=1e-5) + # o_sdpa = sdpa(q, k, v) + # print(o_sdpa) + # np.testing.assert_allclose(q_sdpa, o_sdpa, atol=1e-5) + # print(o_sdpa[..., :64]) + # print() + # print(o_attention[..., :64]) + # np.testing.assert_allclose(o_sdpa, o_attention) diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index 27fbb765b..3717febfb 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -927,19 +927,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2); // SDPA vector instantiations #define instantiate_sdpa_vector(type, head_dim) \ - template [[host_name("sdpa_vector_" #type "_" 
#head_dim)]] \ - [[kernel]] void sdpa_vector( \ - const device type* queries [[buffer(0)]], \ - const device type* keys [[buffer(1)]], \ - const device type* values [[buffer(2)]], \ - device type* out [[buffer(3)]], \ - const constant int& gqa_factor, \ - const constant int& N, \ - const constant size_t& k_stride, \ - const constant float& scale, \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); + instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) #define instantiate_sdpa_vector_heads(type) \ instantiate_sdpa_vector(type, 64) \ @@ -949,4 +937,18 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2); instantiate_sdpa_vector_heads(float) instantiate_sdpa_vector_heads(bfloat16_t) instantiate_sdpa_vector_heads(float16_t) + +// Quantized SDPA vector instantiations +#define instantiate_quant_sdpa_vector(type, head_dim) \ + instantiate_kernel("quant_sdpa_vector_" #type "_" #head_dim, quant_sdpa_vector, type, head_dim, 64, 4) + +#define instantiate_quant_sdpa_vector_heads(type) \ + instantiate_quant_sdpa_vector(type, 64) \ + instantiate_quant_sdpa_vector(type, 96) \ + instantiate_quant_sdpa_vector(type, 128) + +instantiate_quant_sdpa_vector_heads(float) +instantiate_quant_sdpa_vector_heads(bfloat16_t) +instantiate_quant_sdpa_vector_heads(float16_t) + // clang-format on diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 4d4a9180b..e5961f97c 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -16,9 +16,11 @@ template const constant float& scale, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { + uint simd_lid [[thread_index_in_simdgroup]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { constexpr int BN = 32; - constexpr int BD = 32; + constexpr int BD = 4; constexpr int elem_per_thread = D / BD; const int stride = BN * D; @@ -36,9 +38,9 @@ template // Adjust positions const int head_idx = tid.y; const int kv_head_idx = head_idx / gqa_factor; - queries += head_idx * D + simd_lid * elem_per_thread; - keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; - values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; + queries += head_idx * D + quad_lid * elem_per_thread; + keys += kv_head_idx * k_stride + quad_gid * D + quad_lid * elem_per_thread; + values += kv_head_idx * k_stride + quad_gid * D + quad_lid * elem_per_thread; out += head_idx * D + simd_gid * elem_per_thread; // Read the query and 0 the output accumulator @@ -53,7 +55,7 @@ template U sum_exp_score = 0; // For each key - for (int i = simd_gid; i < N; i += BN) { + for (int i = quad_gid; i < N; i += BN) { // Read the key for (int i = 0; i < elem_per_thread; i++) { k[i] = keys[i]; @@ -64,7 +66,7 @@ template for (int i = 0; i < elem_per_thread; i++) { score += q[i] * k[i]; } - score = simd_sum(score); + score = quad_sum(score); // Update the accumulators U new_max = max(max_score, score); @@ -88,9 +90,10 @@ template // Each thread has a partial part of the output so we need to combine them. 
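+  // Layout recap for the reduction below: with BD = 4 the threadgroup has
+  // BN * BD = 32 * 4 = 128 threads. Each quadgroup owns one key/value row at
+  // a time (striding over the sequence by BN), so after the loop every
+  // quadgroup holds a partial max_score, sum_exp_score and output vector.
+  // The code below shares the per-quadgroup max and sum through threadgroup
+  // memory (BN slots, one per quadgroup) and finishes the reduction across a
+  // simdgroup whose 32 lanes line up with the 32 quadgroups.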
// First let's communicate the max and sum_exp - if (simd_lid == 0) { - max_scores[simd_gid] = max_score; - sum_exp_scores[simd_gid] = sum_exp_score; + // Each quadgroup communicates it's max score + if (quad_lid == 0) { + max_scores[quad_gid] = max_score; + sum_exp_scores[quad_gid] = sum_exp_score; } threadgroup_barrier(mem_flags::mem_threadgroup); max_score = max_scores[simd_lid]; @@ -100,9 +103,174 @@ template // Now we need to aggregate all the outputs for (int i = 0; i < elem_per_thread; i++) { - outputs[simd_lid * BD + simd_gid] = o[i]; + // 128 threads with 32 values per thread + outputs[simd_gid * BN + simd_lid] = o[i]; threadgroup_barrier(mem_flags::mem_threadgroup); - o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score; + o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // And write the output + if (simd_lid == 0) { + for (int i = 0; i < elem_per_thread; i++) { + out[i] = static_cast(o[i]); + } + } +} + +template +[[kernel]] void quant_sdpa_vector( + const device T* queries [[buffer(0)]], + const device uint32_t* keys [[buffer(1)]], + const device T* key_scales [[buffer(2)]], + const device T* key_biases [[buffer(3)]], + const device uint32_t* values [[buffer(4)]], + const device T* value_scales [[buffer(5)]], + const device T* value_biases [[buffer(6)]], + device T* out [[buffer(7)]], + const constant int& gqa_factor, + const constant int& N, + const constant size_t& k_stride, + const constant size_t& group_stride, + const constant float& scale, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + constexpr int BN = 32; + constexpr int BD = 4; + constexpr int elem_per_thread = D / BD; + constexpr int pack_factor = 32 / bits; + + const int stride = BN * D; + + typedef float U; + + thread U q[elem_per_thread]; + thread U k[elem_per_thread]; + thread U v[elem_per_thread]; + thread U o[elem_per_thread]; + + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + // Adjust positions + const int head_idx = tid.y; + const int kv_head_idx = head_idx / gqa_factor; + queries += head_idx * D + quad_lid * elem_per_thread; + + const int kv_idx = quad_gid * D + quad_lid * elem_per_thread; + const int packed_idx = kv_head_idx * k_stride + kv_idx / pack_factor; + const int group_idx = kv_head_idx * group_stride + kv_idx / group_size; + keys += packed_idx; + key_scales += group_idx; + key_biases += group_idx; + values += packed_idx; + value_scales += group_idx; + value_biases += group_idx; + + out += head_idx * D + simd_gid * elem_per_thread; + + // Read the query and 0 the output accumulator + U query_sum = 0; + U shifts[4] = {1, 16, 256, 4096}; + for (int i = 0; i < elem_per_thread; i++) { + // Shift by the appropriate amount here + query_sum += queries[i]; + U shift = shifts[i % 4]; + q[i] = static_cast(scale) * queries[i] / shift; + } + for (int i = 0; i < elem_per_thread; i++) { + o[i] = 0; + } + + U max_score = -INFINITY; + U sum_exp_score = 0; + + // For each key + for (int i = quad_gid; i < N; i += BN) { + // Read the key + auto ks = (const device uint16_t*)keys; + for (int i = 0; i < elem_per_thread / 4; i++) { + k[4 * i] = ks[i] & 0x000f; + k[4 * i + 1] = ks[i] & 0x00f0; + k[4 * i + 2] = ks[i] & 0x0f00; + k[4 * i + 3] = ks[i] & 0xf000; + 
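+      // Note that the nibbles above are only masked, not shifted down to the
+      // low bits. The missing shifts are compensated for elsewhere: the query
+      // was pre-divided by {1, 16, 256, 4096} when it was loaded, and the
+      // value scales below are divided by the same factors, so the dot
+      // product and the dequantized values come out correctly scaled.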
} + // All the keys in a set are in the same group + U key_scale = key_scales[0]; + U key_bias = key_biases[0]; + + // Compute the i-th score + U score = 0; + for (int i = 0; i < elem_per_thread; i++) { + score += q[i] * k[i]; + } + score = score * key_scale + query_sum * key_bias; + score = quad_sum(score); + + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + U value_scale = value_scales[0]; + U value_bias = value_biases[0]; + + // Load the values + auto vs = (const device uint16_t*)values; + U s[4] = { + value_scale, + value_scale / 16.0f, + value_scale / 256.0f, + value_scale / 4096.0f}; + for (int i = 0; i < elem_per_thread / 4; i++) { + v[4 * i] = s[0] * (vs[i] & 0x000f) + value_bias; + v[4 * i + 1] = s[1] * (vs[i] & 0x00f0) + value_bias; + v[4 * i + 2] = s[2] * (vs[i] & 0x0f00) + value_bias; + v[4 * i + 3] = s[3] * (vs[i] & 0xf000) + value_bias; + } + + // Update the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + o[i] = o[i] * factor + exp_score * v[i]; + } + + // Move the pointers to the next kv + keys += stride / pack_factor; + key_scales += stride / group_size; + key_biases += stride / group_size; + values += stride / pack_factor; + value_scales += stride / group_size; + value_biases += stride / group_size; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Each thread has a partial part of the output so we need to combine them. + + // First let's communicate the max and sum_exp + // Each quadgroup communicates it's max score + if (quad_lid == 0) { + max_scores[quad_gid] = max_score; + sum_exp_scores[quad_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = max_scores[simd_lid]; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); + + // Now we need to aggregate all the outputs + for (int i = 0; i < elem_per_thread; i++) { + // 128 threads with 32 values per thread + outputs[simd_gid * BN + simd_lid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score; threadgroup_barrier(mem_flags::mem_threadgroup); } diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 7a3fc03ba..e85c95acc 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -9,6 +9,8 @@ #include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" #include "mlx/fast_primitives.h" +#include + namespace mlx::core::fast { namespace { @@ -163,7 +165,7 @@ void sdpa_vector( int N = k.shape(2); int B = q.shape(0) * q.shape(1); size_t stride = k.strides()[1]; - MTL::Size group_dims(1024, 1, 1); + MTL::Size group_dims(128, 1, 1); MTL::Size grid_dims(1, B, 1); // Get the kernel @@ -185,19 +187,67 @@ void sdpa_vector( compute_encoder.dispatchThreadgroups(grid_dims, group_dims); } +void quant_sdpa_vector( + const Stream& s, + metal::Device& d, + const array& q, + const array& k, + const array& k_scales, + const array& k_biases, + const array& v, + const array& v_scales, + const array& v_biases, + array& out, + float scale) { + // Set the kernel name + std::string kname; + kname.reserve(96); + kname += "quant_sdpa_vector_"; + kname += get_type_string(q.dtype()); + 
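+  // group_size = 64 and bits = 4 are currently hard-coded in the Metal
+  // instantiation (instantiate_quant_sdpa_vector), so they are not part of
+  // the kernel name; only the dtype and head dimension are appended.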
kname += "_"; + kname += std::to_string(q.shape(-1)); + + // Compute the necessary sizes + int gqa_factor = q.shape(1) / k.shape(1); + int N = k.shape(2); + int B = q.shape(0) * q.shape(1); + size_t stride = k.strides()[1]; + size_t group_stride = k_scales.strides()[1]; + MTL::Size group_dims(128, 1, 1); + MTL::Size grid_dims(1, B, 1); + + // Get the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname); + compute_encoder->setComputePipelineState(kernel); + + // Set its arguments + compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0); + compute_encoder.set_input_array(k, 1); + compute_encoder.set_input_array(k_scales, 2); + compute_encoder.set_input_array(k_biases, 3); + compute_encoder.set_input_array(v, 4); + compute_encoder.set_input_array(v_scales, 5); + compute_encoder.set_input_array(v_biases, 6); + compute_encoder.set_output_array(out, 7); + compute_encoder->setBytes(&gqa_factor, sizeof(int), 8); + compute_encoder->setBytes(&N, sizeof(int), 9); + compute_encoder->setBytes(&stride, sizeof(size_t), 10); + compute_encoder->setBytes(&group_stride, sizeof(size_t), 11); + compute_encoder->setBytes(&scale, sizeof(float), 12); + + // Launch + compute_encoder.dispatchThreadgroups(grid_dims, group_dims); +} + } // namespace void ScaledDotProductAttention::eval_gpu( const std::vector& inputs, array& out) { - assert(inputs.size() == 3); - auto& s = stream(); auto& d = metal::device(s.device); - auto& q_pre = inputs[0]; - auto& k_pre = inputs[1]; - auto& v_pre = inputs[2]; auto& o = out; std::vector copies; @@ -236,11 +286,25 @@ void ScaledDotProductAttention::eval_gpu( return strides[3] == 1 && strides[2] == shape[3]; }; - // We are in vector mode ie single query - if (q_pre.shape(2) == 1) { + if (quantized_) { + auto& q_pre = inputs[0]; + auto& k_pre = inputs[1]; + auto& k_scales_pre = inputs[2]; + auto& k_biases_pre = inputs[3]; + auto& v_pre = inputs[4]; + auto& v_scales_pre = inputs[5]; + auto& v_biases_pre = inputs[6]; + + // Quantized should only be routed here for single queries + assert(q_pre.shape(2) == 1); + auto q = copy_unless(is_contiguous, q_pre); auto k = copy_unless(is_contiguous_except_seq_len, k_pre); + auto k_scales = copy_unless(is_contiguous_except_seq_len, k_scales_pre); + auto k_biases = copy_unless(is_contiguous_except_seq_len, k_biases_pre); auto v = copy_unless(is_contiguous_except_seq_len, v_pre); + auto v_scales = copy_unless(is_contiguous_except_seq_len, v_scales_pre); + auto v_biases = copy_unless(is_contiguous_except_seq_len, v_biases_pre); // Donate the query if possible if (q.is_donatable()) { @@ -249,17 +313,42 @@ void ScaledDotProductAttention::eval_gpu( o.set_data(allocator::malloc_or_wait(o.nbytes())); } - sdpa_vector(s, d, q, k, v, o, scale_); + quant_sdpa_vector( + s, d, q, k, k_scales, k_biases, v, v_scales, v_biases, o, scale_); + } - // Full attention mode + // Non-quantized else { - auto q = copy_unless(is_matrix_contiguous, q_pre); - auto k = copy_unless(is_matrix_contiguous, k_pre); - auto v = copy_unless(is_matrix_contiguous, v_pre); - o.set_data(allocator::malloc_or_wait(o.nbytes())); + assert(inputs.size() == 3); + auto& q_pre = inputs[0]; + auto& k_pre = inputs[1]; + auto& v_pre = inputs[2]; - sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o); + // We are in vector mode ie single query + if (q_pre.shape(2) == 1) { + auto q = copy_unless(is_contiguous, q_pre); + auto k = copy_unless(is_contiguous_except_seq_len, k_pre); + auto v = 
copy_unless(is_contiguous_except_seq_len, v_pre); + + // Donate the query if possible + if (q.is_donatable()) { + o.move_shared_buffer(q); + } else { + o.set_data(allocator::malloc_or_wait(o.nbytes())); + } + + sdpa_vector(s, d, q, k, v, o, scale_); + } + // Full attention mode + else { + auto q = copy_unless(is_matrix_contiguous, q_pre); + auto k = copy_unless(is_matrix_contiguous, k_pre); + auto v = copy_unless(is_matrix_contiguous, v_pre); + o.set_data(allocator::malloc_or_wait(o.nbytes())); + + sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o); + } } d.add_temporaries(std::move(copies), s.index); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 731a10bad..8573d0988 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -10,6 +10,8 @@ #include "mlx/ops.h" #include "mlx/transforms.h" +#include + namespace mlx::core::fast { std::vector Custom::vjp( @@ -648,7 +650,7 @@ array scaled_dot_product_attention( std::move(out_shape), final_type, std::make_shared( - stream, fallback, scale, false), + stream, fallback, scale, /*needs_mask=*/false, /*quantized=*/false), {q, k, v}); } @@ -662,7 +664,124 @@ array scaled_dot_product_attention( bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { const ScaledDotProductAttention& a_other = static_cast(other); - return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_; + return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_ && + quantized_ == a_other.quantized_; +} + +array quantized_scaled_dot_product_attention( + const array& queries, + const array& keys, + const array& key_scales, + const array& key_biases, + const array& values, + const array& value_scales, + const array& value_biases, + const float scale, + const std::optional& mask, + const int group_size, + const int bits, + StreamOrDevice s) { + int el_per_int = 32 / bits; + int out_dim = values.shape(-1) * el_per_int; + + auto n_q_heads = queries.shape(-3); + auto n_kv_heads = keys.shape(-3); + + std::cout << "group bits " << group_size << " " << bits << std::endl; + auto out_shape = std::vector( + {queries.shape(0), queries.shape(1), queries.shape(2), out_dim}); + auto stream = to_stream(s); + bool needs_mask = mask.has_value(); + auto fallback = + [scale, needs_mask, n_q_heads, n_kv_heads, group_size, bits, &s]( + const std::vector& inputs) -> std::vector { + int n_repeats = n_q_heads / n_kv_heads; + + auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s); + + auto k = inputs[1]; + auto k_scales = inputs[2]; + auto k_biases = inputs[3]; + auto v = inputs[4]; + auto v_scales = inputs[5]; + auto v_biases = inputs[6]; + + int B = q.shape(0); + int L = q.shape(2); + + if (n_repeats > 1) { + q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s); + k = expand_dims(k, 2, s); + k_scales = expand_dims(k_scales, 2, s); + k_biases = expand_dims(k_biases, 2, s); + v = expand_dims(v, 2, s); + v_scales = expand_dims(v_scales, 2, s); + v_biases = expand_dims(v_biases, 2, s); + } + + array scores = quantized_matmul( + q, + k, + k_scales, + k_biases, + /*transpose=*/true, + /*group_size=*/group_size, + /*bits=*/bits, + s); + if (needs_mask) { + scores = add(scores, inputs[7], s); + } + scores = softmax(scores, std::vector{-1}, true, s); + array out = quantized_matmul( + scores, + v, + v_scales, + v_biases, + /*transpose=*/false, + /*group_size=*/group_size, + /*bits=*/bits, + s); + if (n_repeats > 1) { + out = reshape(out, {B, n_q_heads, L, -1}, s); + } + return std::vector{out}; + }; + + if (true) { + if (needs_mask) { + return fallback( + {queries, 
+ keys, + key_scales, + key_biases, + values, + value_scales, + value_biases, + mask.value()})[0]; + } else { + return fallback( + {queries, + keys, + key_scales, + key_biases, + values, + value_scales, + value_biases})[0]; + } + } else { + return array( + std::move(out_shape), + queries.dtype(), + std::make_shared( + stream, fallback, scale, /*needs_mask=*/false, /*quantized=*/true), + {queries, + keys, + key_scales, + key_biases, + values, + value_scales, + value_biases}); + } } array pack_and_quantize( diff --git a/mlx/fast.h b/mlx/fast.h index e1a876882..987aa8ce8 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -41,6 +41,21 @@ array scaled_dot_product_attention( const std::optional memory_efficient_threshold = std::nullopt, StreamOrDevice s = {}); +/** Computes: `O = softmax(Q @ K.T) @ V` where K and V are quantized. **/ +array quantized_scaled_dot_product_attention( + const array& queries, + const array& keys, + const array& key_scales, + const array& key_biases, + const array& values, + const array& value_scales, + const array& value_biases, + const float scale, + const std::optional& mask = std::nullopt, + const int group_size = 64, + const int bits = 4, + StreamOrDevice s = {}); + std::tuple affine_quantize( const array& w, int group_size = 64, diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 9233a1628..f97808d35 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -190,8 +190,12 @@ class ScaledDotProductAttention : public Custom { Stream stream, std::function(std::vector)> fallback, const float scale, - const bool needs_mask) - : Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {} + const bool needs_mask, + const bool quantized) + : Custom(stream, fallback), + scale_(scale), + needs_mask_(needs_mask), + quantized_(quantized) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -212,6 +216,7 @@ class ScaledDotProductAttention : public Custom { std::function(std::vector)> fallback_; float scale_; bool needs_mask_; + bool quantized_; }; class AffineQuantize : public Custom { diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 758a27530..83981481a 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -150,6 +150,49 @@ void init_fast(nb::module_& parent_module) { array: The output array. )pbdoc"); + m.def( + "quantized_scaled_dot_product_attention", + &fast::quantized_scaled_dot_product_attention, + "q"_a, + "k"_a, + "k_scales"_a, + "k_biases"_a, + "v"_a, + "v_scales"_a, + "v_biases"_a, + nb::kw_only(), + "scale"_a, + "mask"_a = nb::none(), + "group_size"_a = 64, + "bits"_a = 4, + "stream"_a = nb::none(), + nb::sig( + "def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: array, v: array, v_scales: array, v_biases: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``. + + Supports: + + * `Multi-Head Attention `_ + * `Grouped Query Attention `_ + * `Multi-Query Attention `_ + + Note: The softmax operation is performed in ``float32`` regardless of + the input precision. + + Note: For Grouped Query Attention and Multi-Query Attention, the ``k`` + and ``v`` inputs should not be pre-tiled to match ``q``. + + Args: + q (array): Input query array. + k (array): Input keys array. + v (array): Input values array. 
+          k_scales (array): Scales for the quantized ``k``.
+          k_biases (array): Biases for the quantized ``k``.
+          v_scales (array): Scales for the quantized ``v``.
+          v_biases (array): Biases for the quantized ``v``.
+          scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1))``)
+          mask (array, optional): An additive mask to apply to the query-key scores.
+          group_size (int, optional): The group size used to quantize ``k`` and ``v``. Default: ``64``.
+          bits (int, optional): The number of bits used to quantize ``k`` and ``v``. Default: ``4``.
+      Returns:
+          array: The output array.
+      )pbdoc");
+
   m.def(
       "affine_quantize",
       nb::overload_cast<
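
For reference, a minimal end-to-end sketch of the new binding from the Python
side, mirroring the check in benchmarks/python/sdpa_vector_bench.py. The
shapes, the tolerance, and the explicit group_size=64 / bits=4 arguments are
illustrative assumptions, not something this patch prescribes:

import mlx.core as mx

# Illustrative shapes: single query token against a quantized KV cache.
B, n_q_heads, n_kv_heads, L, D = 1, 32, 8, 4096, 128

q = mx.random.uniform(shape=(B, n_q_heads, 1, D))
k = mx.random.uniform(shape=(B, n_kv_heads, L, D))
v = mx.random.uniform(shape=(B, n_kv_heads, L, D))

# Quantize the KV cache once; mx.quantize returns (packed, scales, biases).
k_q = mx.quantize(k, group_size=64, bits=4)
v_q = mx.quantize(v, group_size=64, bits=4)

scale = 1.0 / (D ** 0.5)
out = mx.fast.quantized_scaled_dot_product_attention(
    q, *k_q, *v_q, scale=scale, group_size=64, bits=4
)

# Cross-check against dequantizing and calling the existing kernel.
ref = mx.fast.scaled_dot_product_attention(
    q,
    mx.dequantize(*k_q, group_size=64, bits=4),
    mx.dequantize(*v_q, group_size=64, bits=4),
    scale=scale,
)
assert mx.allclose(out, ref, atol=1e-4)

With the Metal kernel path still disabled in fast.cpp (the if (true) branch),
the quantized call runs the quantized_matmul fallback, so it should agree with
the dequantize-then-attend reference up to numerical error.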