diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 9a95e08ac..76c8bdf1a 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -162,8 +162,8 @@ array scaled_dot_product_attention(
   }
 
   // K, V must have matching number of heads (n_kv_heads);
-  size_t n_q_heads = queries.shape(-3);
-  size_t n_kv_heads = keys.shape(-3);
+  auto n_q_heads = queries.shape(-3);
+  auto n_kv_heads = keys.shape(-3);
 
   if (keys.shape(-3) != values.shape(-3)) {
     std::ostringstream msg;
@@ -207,53 +207,43 @@ array scaled_dot_product_attention(
   bool needs_mask = mask.has_value();
   auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
                       const std::vector<array>& inputs) {
-    auto& q_tensor = inputs[0];
-    auto& k_tensor = inputs[1];
-    auto& v_tensor = inputs[2];
-    auto q_scaled = multiply(array(scale, q_tensor.dtype()), q_tensor, s);
-
-    auto tile_if_needs_repeat =
-        [n_q_heads, n_kv_heads](const array& arr, StreamOrDevice& s) -> array {
-      if (n_q_heads == n_kv_heads)
-        return arr;
-      int n_repeats = n_q_heads / n_kv_heads;
-      constexpr const int heads_axis =
-          1; // heads axis, assumes tensors arranged as [0, 1, 2, 3] ->
-             // [Batch, Heads, Sequence, Hidden]
-      auto ret = repeat(arr, n_repeats, heads_axis, s);
-      return ret;
-    };
-    auto k_tensor_tiled = tile_if_needs_repeat(k_tensor, s);
-    auto v_tensor_tiled = tile_if_needs_repeat(v_tensor, s);
-
-    // dim check on k, v; repeat if untiled, since naive matmul will have
-    // dim mismatch for GQA (MQA could make use of broadcast)
-    auto k_transposed = transpose(k_tensor_tiled, {0, 1, 3, 2}, s);
-    auto s_tensor = matmul(q_scaled, k_transposed, s);
-    if (needs_mask) {
-      auto mask_tensor = inputs[3];
-      s_tensor = add(s_tensor, mask_tensor, s);
+    auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
+    int n_repeats = n_q_heads / n_kv_heads;
+    int B = q.shape(0);
+    int L = q.shape(2);
+    auto k = inputs[1];
+    auto v = inputs[2];
+    if (n_repeats > 1) {
+      q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
+      k = expand_dims(k, 2, s);
+      v = expand_dims(v, 2, s);
     }
-    auto p = astype(
-        softmax(astype(s_tensor, float32, s), std::vector<int>{-1}, s),
+    auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
+    if (needs_mask) {
+      scores = add(scores, inputs[3], s);
+    }
+    scores = astype(
+        softmax(astype(scores, float32, s), std::vector<int>{-1}, s),
         final_type,
         s);
-    auto out_tensor = matmul(p, v_tensor_tiled, s);
-    return std::vector<array>{out_tensor};
+    auto out = matmul(scores, v, s);
+    if (n_repeats > 1) {
+      out = reshape(out, {B, n_q_heads, L, -1}, s);
+    }
+    return std::vector<array>{out};
   };
 
   auto stream = to_stream(s);
-
-  // current implementation use case: batch size 1, query sequence length 1, no
-  // mask. Likewise, requires head_dim == 128
   constexpr const int supported_head_dim = 128;
   const size_t query_head_dim = q.shape(-1);
   const size_t query_sequence_length = q.shape(2);
   bool implementation_supports_use_case = batch_dim == 1 &&
       query_sequence_length == 1 && !mask.has_value() &&
-      query_head_dim == supported_head_dim && final_type != bfloat16;
-
-  if (stream.device == Device::gpu && implementation_supports_use_case) {
+      query_head_dim == supported_head_dim && final_type != bfloat16 &&
+      stream.device == Device::gpu;
+  // TODO, update routing conditions post further tuning
+  implementation_supports_use_case &= false;
+  if (implementation_supports_use_case) {
     auto out = array(
         out_shape,
         final_type,