diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index 9beec77b13..5292f8685b 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -22,18 +22,14 @@ instantiate_sdpa_vector_heads(bfloat16_t) instantiate_sdpa_vector_heads(float16_t) // Quantized SDPA vector instantiations -#define instantiate_quant_sdpa_vector(name, type, head_dim, group_size, bits) \ +#define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \ instantiate_kernel( \ - #name "_" #type "_" #head_dim "_" #group_size "_" #bits, \ - name, type, head_dim, group_size, bits) - -#define instantiate_quant_sdpa_vector_passes(type, heads, group_size, bits) \ - instantiate_quant_sdpa_vector(quant_sdpa_vector, type, heads, group_size, bits) \ - instantiate_quant_sdpa_vector(quant_sdpa_vector_2pass_1, type, heads, group_size, bits) + "quant_sdpa_vector_2pass_1_" #type "_" #head_dim "_" #group_size "_" #bits, \ + quant_sdpa_vector_2pass_1, type, head_dim, group_size, bits) #define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \ - instantiate_quant_sdpa_vector_passes(type, heads, group_size, 4) \ - instantiate_quant_sdpa_vector_passes(type, heads, group_size, 8) + instantiate_quant_sdpa_vector(type, heads, group_size, 4) \ + instantiate_quant_sdpa_vector(type, heads, group_size, 8) #define instantiate_quant_sdpa_vector_group_size(type, heads) \ instantiate_quant_sdpa_vector_bits(type, heads, 32) \ @@ -42,7 +38,6 @@ instantiate_sdpa_vector_heads(float16_t) #define instantiate_quant_sdpa_vector_heads(type) \ instantiate_quant_sdpa_vector_group_size(type, 64) \ - instantiate_quant_sdpa_vector_group_size(type, 96) \ instantiate_quant_sdpa_vector_group_size(type, 128) instantiate_quant_sdpa_vector_heads(float) diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 49eb35b1fe..a03ee3d9b9 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -174,147 +174,6 @@ METAL_FUNC void load_values( } } -template -[[kernel]] void quant_sdpa_vector( - const device T* queries [[buffer(0)]], - const device uint32_t* keys [[buffer(1)]], - const device T* key_scales [[buffer(2)]], - const device T* key_biases [[buffer(3)]], - const device uint32_t* values [[buffer(4)]], - const device T* value_scales [[buffer(5)]], - const device T* value_biases [[buffer(6)]], - device T* out [[buffer(7)]], - const constant int& gqa_factor, - const constant int& N, - const constant size_t& k_stride, - const constant size_t& group_stride, - const constant float& scale, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]], - uint quad_gid [[quadgroup_index_in_threadgroup]], - uint quad_lid [[thread_index_in_quadgroup]]) { - constexpr int BN = 32; - constexpr int BD = 4; - constexpr int elem_per_thread = D / BD; - constexpr int pack_factor = 32 / bits; - - const int stride = BN * D; - - typedef float U; - - thread U q[elem_per_thread]; - thread U k[elem_per_thread]; - thread U v[elem_per_thread]; - thread U o[elem_per_thread]; - - threadgroup U outputs[BN * BD]; - threadgroup U max_scores[BN]; - threadgroup U sum_exp_scores[BN]; - - // Adjust positions - const int head_idx = tid.y; - const int kv_head_idx = head_idx / gqa_factor; - queries += head_idx * D + quad_lid * elem_per_thread; - - const int 
kv_idx = quad_gid * D + quad_lid * elem_per_thread; - const int packed_idx = kv_head_idx * k_stride + kv_idx / pack_factor; - const int group_idx = kv_head_idx * group_stride + kv_idx / group_size; - keys += packed_idx; - key_scales += group_idx; - key_biases += group_idx; - values += packed_idx; - value_scales += group_idx; - value_biases += group_idx; - - out += head_idx * D + simd_gid * elem_per_thread; - - // Read the query and 0 the output accumulator - U query_sum = load_queries( - queries, q, static_cast(scale)); - for (int i = 0; i < elem_per_thread; i++) { - o[i] = 0; - } - - U max_score = -INFINITY; - U sum_exp_score = 0; - - // For each key - for (int i = quad_gid; i < N; i += BN) { - load_keys(keys, k); - - // Assume D % group_size == 0 so all the keys are in the same group - U key_scale = key_scales[0]; - U key_bias = key_biases[0]; - - // Compute the i-th score - U score = 0; - for (int i = 0; i < elem_per_thread; i++) { - score += q[i] * k[i]; - } - score = score * key_scale + query_sum * key_bias; - score = quad_sum(score); - - // Update the accumulators - U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); - - max_score = new_max; - sum_exp_score = sum_exp_score * factor + exp_score; - - U value_scale = value_scales[0]; - U value_bias = value_biases[0]; - - // Load the values - load_values(values, v, value_scale, value_bias); - - // Update the output accumulator - for (int i = 0; i < elem_per_thread; i++) { - o[i] = o[i] * factor + exp_score * v[i]; - } - - // Move the pointers to the next kv - keys += stride / pack_factor; - key_scales += stride / group_size; - key_biases += stride / group_size; - values += stride / pack_factor; - value_scales += stride / group_size; - value_biases += stride / group_size; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Each thread has a partial part of the output so we need to combine them. 
- - // First let's communicate the max and sum_exp - // Each quadgroup communicates it's max score - if (quad_lid == 0) { - max_scores[quad_gid] = max_score; - sum_exp_scores[quad_gid] = sum_exp_score; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - max_score = max_scores[simd_lid]; - U new_max = simd_max(max_score); - U factor = fast::exp(max_score - new_max); - sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); - - // Now we need to aggregate all the outputs - for (int i = 0; i < elem_per_thread; i++) { - // 128 threads with 32 values per thread - outputs[simd_gid * BN + simd_lid] = o[i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // And write the output - if (simd_lid == 0) { - for (int i = 0; i < elem_per_thread; i++) { - out[i] = static_cast(o[i]); - } - } -} - template [[kernel]] void sdpa_vector_2pass_1( const device T* queries [[buffer(0)]], diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 0f87e6027c..d82d55bac3 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -242,65 +242,6 @@ void sdpa_vector_2pass( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } -void quant_sdpa_vector( - const Stream& s, - metal::Device& d, - const array& q, - const array& k, - const array& k_scales, - const array& k_biases, - const array& v, - const array& v_scales, - const array& v_biases, - array& out, - float scale, - int group_size, - int bits) { - // Set the kernel name - std::string kname; - kname.reserve(96); - kname += "quant_sdpa_vector_"; - kname += get_type_string(q.dtype()); - kname += "_"; - kname += std::to_string(q.shape(-1)); - kname += "_"; - kname += std::to_string(group_size); - kname += "_"; - kname += std::to_string(bits); - - // Compute the necessary sizes - int gqa_factor = q.shape(1) / k.shape(1); - int N = k.shape(2); - int B = q.shape(0) * q.shape(1); - size_t stride = k.strides()[1]; - size_t group_stride = k_scales.strides()[1]; - MTL::Size group_dims(128, 1, 1); - MTL::Size grid_dims(1, B, 1); - - // Get the kernel - auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname); - compute_encoder.set_compute_pipeline_state(kernel); - - // Set its arguments - compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? 
out : q, 0); - compute_encoder.set_input_array(k, 1); - compute_encoder.set_input_array(k_scales, 2); - compute_encoder.set_input_array(k_biases, 3); - compute_encoder.set_input_array(v, 4); - compute_encoder.set_input_array(v_scales, 5); - compute_encoder.set_input_array(v_biases, 6); - compute_encoder.set_output_array(out, 7); - compute_encoder.set_bytes(&gqa_factor, sizeof(int), 8); - compute_encoder.set_bytes(&N, sizeof(int), 9); - compute_encoder.set_bytes(&stride, sizeof(size_t), 10); - compute_encoder.set_bytes(&group_stride, sizeof(size_t), 11); - compute_encoder.set_bytes(&scale, sizeof(float), 12); - - // Launch - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); -} - void quant_sdpa_vector_2pass( const Stream& s, metal::Device& d, diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 37a6ec47bc..3848bd58cc 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -516,15 +516,11 @@ bool RoPE::is_equivalent(const Primitive& other) const { offset_ == a_other.offset_ && forward_ == a_other.forward_); } -/** Computes: O = softmax(Q @ K.T) @ V **/ -array scaled_dot_product_attention( +void check_sdpa_arguments( const array& queries, const array& keys, const array& values, - const float scale, - const std::optional& mask, - const std::optional memory_efficient_threshold, - StreamOrDevice s) { + const std::optional& mask) { for (const auto& tensor : {queries, keys, values}) { if (tensor.ndim() != 4) { std::ostringstream msg; @@ -550,14 +546,6 @@ array scaled_dot_product_attention( } } - // Q, K must have matching last dims (d_k aka 'head_dim'); - if (queries.shape(-1) != keys.shape(-1)) { - std::ostringstream msg; - msg << "[scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape " - << queries.shape() << " for keys shape " << keys.shape() << "."; - throw std::invalid_argument(msg.str()); - } - // K, V must have matching number of heads (n_kv_heads); auto n_q_heads = queries.shape(-3); auto n_kv_heads = keys.shape(-3); @@ -577,6 +565,26 @@ array scaled_dot_product_attention( << n_q_heads << " for n_kv_heads " << n_kv_heads << "."; throw std::invalid_argument(msg.str()); } +} + +/** Computes: O = softmax(Q @ K.T) @ V **/ +array scaled_dot_product_attention( + const array& queries, + const array& keys, + const array& values, + const float scale, + const std::optional& mask, + const std::optional memory_efficient_threshold, + StreamOrDevice s) { + check_sdpa_arguments(queries, keys, values, mask); + + // Q, K must have matching last dims (d_k aka 'head_dim'); + if (queries.shape(-1) != keys.shape(-1)) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape " + << queries.shape() << " for keys shape " << keys.shape() << "."; + throw std::invalid_argument(msg.str()); + } auto final_type = result_type(queries, keys, values); if (!issubdtype(final_type, floating)) { @@ -590,6 +598,9 @@ array scaled_dot_product_attention( auto k = astype(keys, final_type, s); auto v = astype(values, final_type, s); + auto n_q_heads = queries.shape(-3); + auto n_kv_heads = keys.shape(-3); + /* generic implementation for use cases that Metal implementation does not * support. 
For non-supported cases listed below, use MLX primitives: * * CPU implementation @@ -696,6 +707,25 @@ array quantized_scaled_dot_product_attention( const int bits, StreamOrDevice s) { int el_per_int = 32 / bits; + + check_sdpa_arguments(queries, keys, values, mask); + + // Q, K must have matching last dims (d_k aka 'head_dim'); + if (queries.shape(-1) != keys.shape(-1) * el_per_int) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape " + << queries.shape() << " for keys shape " << keys.shape() << "."; + throw std::invalid_argument(msg.str()); + } + + auto final_type = result_type(queries, key_scales, value_scales); + if (!issubdtype(final_type, floating)) { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] Received unsupported type " + << final_type << "."; + throw std::invalid_argument(msg.str()); + } + int out_dim = values.shape(-1) * el_per_int; auto n_q_heads = queries.shape(-3); @@ -760,8 +790,9 @@ array quantized_scaled_dot_product_attention( return std::vector{out}; }; + int query_head_dim = queries.shape(-1); int L = queries.shape(2); - if (L > 1) { + if (L > 1 && query_head_dim != 64 && query_head_dim != 128) { if (needs_mask) { return fallback( {queries,
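
Note on the kernel surface after this change: instantiate_quant_sdpa_vector(float, 128, 32, 4) now expands, via instantiate_kernel, to a single specialization named quant_sdpa_vector_2pass_1_float_128_32_4, whereas the old macro pair emitted both a quant_sdpa_vector_* and a quant_sdpa_vector_2pass_1_* kernel per configuration; the head-dim-96 variants are no longer instantiated, and the dispatch in fast.cpp now takes the query head dim into account alongside the sequence length when deciding whether to use the generic fallback. The snippet below is a minimal standalone C++ sketch of the arithmetic behind the added shape check (queries.shape(-1) == keys.shape(-1) * el_per_int), assuming the 32-bit-word packing used by the quantized kernels; el_per_int mirrors the variable in fast.cpp, while head_dims_match is an illustrative helper, not MLX API.

// Illustrative sketch only (not MLX API): a 32-bit word packs 32 / bits
// quantized elements, so a key row stored as keys.shape(-1) uint32 words
// dequantizes to keys.shape(-1) * el_per_int values, which must equal the
// query head dim.
#include <cassert>

constexpr int el_per_int(int bits) {
  return 32 / bits; // 8 elements per word at 4 bits, 4 at 8 bits
}

constexpr bool head_dims_match(int query_head_dim, int packed_key_dim, int bits) {
  return query_head_dim == packed_key_dim * el_per_int(bits);
}

int main() {
  static_assert(head_dims_match(128, 16, 4), "4-bit keys: 128 / 8 = 16 words per row");
  static_assert(head_dims_match(64, 16, 8), "8-bit keys: 64 / 4 = 16 words per row");
  return 0;
}

If the packed key width and the query head dim disagree, the new check in quantized_scaled_dot_product_attention raises std::invalid_argument before any kernel is selected.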