diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
index 4abef4c49..c668d9d8c 100644
--- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
+++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
@@ -1,11 +1,11 @@
 #include <metal_stdlib>
 
-#include "mlx/backend/metal/kernels/sdpa_vector.h"
+// clang-format off
 #include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/sdpa_vector.h"
 
 using namespace metal;
 
-// clang-format off
 // SDPA vector instantiations
 #define instantiate_sdpa_vector_aggregation(type, value_dim) \
   instantiate_kernel(                                        \
diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h
index a83989b17..c4c0f6456 100644
--- a/mlx/backend/metal/kernels/sdpa_vector.h
+++ b/mlx/backend/metal/kernels/sdpa_vector.h
@@ -7,6 +7,8 @@ using namespace metal;
 constant bool has_mask [[function_constant(20)]];
 constant bool query_transposed [[function_constant(21)]];
 constant bool do_causal [[function_constant(22)]];
+constant bool bool_mask [[function_constant(23)]];
+constant bool float_mask [[function_constant(24)]];
 
 template <typename T, int D, int V = D>
 [[kernel]] void sdpa_vector(
@@ -14,17 +16,21 @@
     const device T* keys [[buffer(1)]],
     const device T* values [[buffer(2)]],
     device T* out [[buffer(3)]],
-    const constant int& gqa_factor,
-    const constant int& N,
-    const constant size_t& k_head_stride,
-    const constant size_t& k_seq_stride,
-    const constant size_t& v_head_stride,
-    const constant size_t& v_seq_stride,
-    const constant float& scale,
-    const device bool* mask [[function_constant(has_mask)]],
-    const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
-    const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
-    const constant int& mask_head_stride [[function_constant(has_mask)]],
+    const constant int& gqa_factor [[buffer(4)]],
+    const constant int& N [[buffer(5)]],
+    const constant size_t& k_head_stride [[buffer(6)]],
+    const constant size_t& k_seq_stride [[buffer(7)]],
+    const constant size_t& v_head_stride [[buffer(8)]],
+    const constant size_t& v_seq_stride [[buffer(9)]],
+    const constant float& scale [[buffer(10)]],
+    const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
+    const device T* fmask [[buffer(12), function_constant(float_mask)]],
+    const constant int& mask_kv_seq_stride
+        [[buffer(13), function_constant(has_mask)]],
+    const constant int& mask_q_seq_stride
+        [[buffer(14), function_constant(has_mask)]],
+    const constant int& mask_head_stride
+        [[buffer(15), function_constant(has_mask)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -58,8 +64,12 @@ template <typename T, int D, int V = D>
       simd_lid * qk_per_thread;
   values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
       simd_lid * v_per_thread;
-  if (has_mask) {
-    mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
+  if (bool_mask) {
+    bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
+        q_seq_idx * mask_q_seq_stride;
+  }
+  if (float_mask) {
+    fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
         q_seq_idx * mask_q_seq_stride;
   }
 
@@ -81,8 +91,8 @@ template <typename T, int D, int V = D>
     bool use_key = true;
     if (do_causal) {
       use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
-    } else if (has_mask) {
-      use_key = mask[0];
+    } else if (bool_mask) {
+      use_key = bmask[0];
     }
     if (use_key) {
       // Read the key
@@ -96,6 +106,9 @@ template <typename T, int D, int V = D>
         score += q[j] * k[j];
      }
       score = simd_sum(score);
+      if (float_mask) {
+        score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
+      }
 
       // Update the accumulators
       U new_max = max(max_score, score);
@@ -114,8 +127,11 @@ template <typename T, int D, int V = D>
     // Move the pointers to the next kv
     keys += inner_k_stride;
     values += inner_v_stride;
-    if (has_mask) {
-      mask += BN * mask_kv_seq_stride;
+    if (bool_mask) {
+      bmask += BN * mask_kv_seq_stride;
+    }
+    if (float_mask) {
+      fmask += BN * mask_kv_seq_stride;
     }
   }
 
@@ -156,17 +172,21 @@ template <typename T, int D, int V = D>
     device float* out [[buffer(3)]],
     device float* sums [[buffer(4)]],
     device float* maxs [[buffer(5)]],
-    const constant int& gqa_factor,
-    const constant int& N,
-    const constant size_t& k_head_stride,
-    const constant size_t& k_seq_stride,
-    const constant size_t& v_head_stride,
-    const constant size_t& v_seq_stride,
-    const constant float& scale,
-    const device bool* mask [[function_constant(has_mask)]],
-    const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
-    const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
-    const constant int& mask_head_stride [[function_constant(has_mask)]],
+    const constant int& gqa_factor [[buffer(6)]],
+    const constant int& N [[buffer(7)]],
+    const constant size_t& k_head_stride [[buffer(8)]],
+    const constant size_t& k_seq_stride [[buffer(9)]],
+    const constant size_t& v_head_stride [[buffer(10)]],
+    const constant size_t& v_seq_stride [[buffer(11)]],
+    const constant float& scale [[buffer(12)]],
+    const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
+    const device T* fmask [[buffer(14), function_constant(float_mask)]],
+    const constant int& mask_kv_seq_stride
+        [[buffer(15), function_constant(has_mask)]],
+    const constant int& mask_q_seq_stride
+        [[buffer(16), function_constant(has_mask)]],
+    const constant int& mask_head_stride
+        [[buffer(17), function_constant(has_mask)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -204,8 +224,13 @@ template <typename T, int D, int V = D>
   values += kv_head_idx * v_head_stride +
       (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
   out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
-  if (has_mask) {
-    mask += head_idx * mask_head_stride +
+  if (bool_mask) {
+    bmask += head_idx * mask_head_stride +
+        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
+        q_seq_idx * mask_q_seq_stride;
+  }
+  if (float_mask) {
+    fmask += head_idx * mask_head_stride +
         (block_idx * BN + simd_gid) * mask_kv_seq_stride +
         q_seq_idx * mask_q_seq_stride;
   }
@@ -228,8 +253,8 @@ template <typename T, int D, int V = D>
     bool use_key = true;
     if (do_causal) {
       use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
-    } else if (has_mask) {
-      use_key = mask[0];
+    } else if (bool_mask) {
+      use_key = bmask[0];
     }
     if (use_key) {
       // Read the key
@@ -243,6 +268,9 @@ template <typename T, int D, int V = D>
         score += q[i] * k[i];
       }
       score = simd_sum(score);
+      if (float_mask) {
+        score += fmask[0];
+      }
 
       // Update the accumulators
       U new_max = max(max_score, score);
@@ -261,8 +289,11 @@ template <typename T, int D, int V = D>
     // Move the pointers to the next kv
     keys += blocks * inner_k_stride;
     values += blocks * inner_v_stride;
-    if (has_mask) {
-      mask += BN * blocks * mask_kv_seq_stride;
+    if (bool_mask) {
+      bmask += BN * blocks * mask_kv_seq_stride;
+    }
+    if (float_mask) {
+      fmask += BN * blocks * mask_kv_seq_stride;
     }
   }
 
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 094756ac5..845962d01 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -163,14 +163,18 @@ void sdpa_vector(
   MTL::Size grid_dims(B, q.shape(2), 1);
 
   bool has_mask = mask.has_value();
+  bool bool_mask = has_mask && (*mask).dtype() == bool_;
+  bool float_mask = has_mask && !bool_mask;
   bool query_transposed = !q.flags().row_contiguous;
   metal::MTLFCList func_consts = {
       {&has_mask, MTL::DataType::DataTypeBool, 20},
       {&query_transposed, MTL::DataType::DataTypeBool, 21},
       {&do_causal, MTL::DataType::DataTypeBool, 22},
+      {&bool_mask, MTL::DataType::DataTypeBool, 23},
+      {&float_mask, MTL::DataType::DataTypeBool, 24},
   };
   std::string hash_name = kname;
-  hash_name += has_mask ? "_mask" : "_nomask";
+  hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
   hash_name += query_transposed ? "_qt" : "_qnt";
   hash_name += do_causal ? "_c" : "_nc";
 
@@ -194,15 +198,15 @@ void sdpa_vector(
   compute_encoder.set_bytes(scale, 10);
   if (has_mask) {
     auto& m = *mask;
-    compute_encoder.set_input_array(m, 11);
+    compute_encoder.set_input_array(m, 11 + float_mask);
     auto nd = m.ndim();
     int32_t kv_seq_stride =
         nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
     int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
     int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
-    compute_encoder.set_bytes(kv_seq_stride, 12);
-    compute_encoder.set_bytes(q_seq_stride, 13);
-    compute_encoder.set_bytes(head_stride, 14);
+    compute_encoder.set_bytes(kv_seq_stride, 13);
+    compute_encoder.set_bytes(q_seq_stride, 14);
+    compute_encoder.set_bytes(head_stride, 15);
   }
 
   // Launch
@@ -260,14 +264,18 @@ void sdpa_vector_2pass(
   d.add_temporary(maxs, s.index);
 
   bool has_mask = mask.has_value();
+  bool bool_mask = has_mask && (*mask).dtype() == bool_;
+  bool float_mask = has_mask && !bool_mask;
   bool query_transposed = !q.flags().row_contiguous;
   metal::MTLFCList func_consts = {
       {&has_mask, MTL::DataType::DataTypeBool, 20},
       {&query_transposed, MTL::DataType::DataTypeBool, 21},
       {&do_causal, MTL::DataType::DataTypeBool, 22},
+      {&bool_mask, MTL::DataType::DataTypeBool, 23},
+      {&float_mask, MTL::DataType::DataTypeBool, 24},
   };
   std::string hash_name = kname;
-  hash_name += has_mask ? "_mask" : "_nomask";
+  hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
   hash_name += query_transposed ? "_qt" : "_qnt";
   hash_name += do_causal ? "_c" : "_nc";
 
@@ -293,15 +301,15 @@ void sdpa_vector_2pass(
   compute_encoder.set_bytes(scale, 12);
   if (has_mask) {
     auto& m = *mask;
-    compute_encoder.set_input_array(m, 13);
+    compute_encoder.set_input_array(m, 13 + float_mask);
     auto nd = m.ndim();
     int32_t kv_seq_stride =
         nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
     int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
     int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
-    compute_encoder.set_bytes(kv_seq_stride, 14);
-    compute_encoder.set_bytes(q_seq_stride, 15);
-    compute_encoder.set_bytes(head_stride, 16);
+    compute_encoder.set_bytes(kv_seq_stride, 15);
+    compute_encoder.set_bytes(q_seq_stride, 16);
+    compute_encoder.set_bytes(head_stride, 17);
   }
 
   // Launch
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 4799fedc3..77210f713 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -739,8 +739,6 @@ array scaled_dot_product_attention(
   const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
 
-  const bool sdpa_vector_supported_mask =
-      !has_mask || has_bool_mask || do_causal;
   const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
       (query_sequence_length <= key_sequence_length && do_causal);
 
@@ -749,8 +747,7 @@
   const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
       (query_sequence_length <= key_sequence_length) &&
-      sdpa_vector_supported_mask && sdpa_vector_supported_head_dim &&
-      stream.device == Device::gpu;
+      sdpa_vector_supported_head_dim && stream.device == Device::gpu;
 
   const bool implementation_supports_use_case =
       supports_sdpa_full || supports_sdpa_vector;
 
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index e7b7e5ac3..d35a2b1da 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -352,6 +352,10 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
             mx.array([True] * (L - 10) + [False] * 10),
             mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
             mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
+            mx.random.uniform(shape=(Nq, 1, L)),
+            mx.random.uniform(shape=(L, 1, Nq)).T,
+            mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),
+            mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),
             "causal",
         ]
         for m in masks:
@@ -377,6 +381,10 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
             mx.array([True] * (L - 10) + [False] * 10),
             mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
             mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
+            mx.random.uniform(shape=(Nq, 1, L)),
+            mx.random.uniform(shape=(L, 1, Nq)).T,
+            mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),
+            mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),
            "causal",
         ]
         for m in masks:
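Not part of the patch above: a minimal usage sketch of what this diff enables, written against the existing Python entry point mx.fast.scaled_dot_product_attention. It exercises the new additive (float) mask path of the vector SDPA kernel and compares it with an unfused reference. The shapes, mask values, and tolerances are illustrative assumptions, not taken from the patch.

```python
# Hypothetical check of the additive float-mask path (sketch only).
import mlx.core as mx

B, H, Lq, Lk, D = 1, 8, 1, 256, 64  # Lq <= 8 keeps us on the sdpa_vector path
scale = D**-0.5

q = mx.random.normal(shape=(B, H, Lq, D))
k = mx.random.normal(shape=(B, H, Lk, D))
v = mx.random.normal(shape=(B, H, Lk, D))

# Additive float mask: 0 keeps a position, a large negative value drops it.
fmask = mx.where(mx.random.uniform(shape=(Lq, Lk)) > 0.2, 0.0, -1e9)

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=fmask)

# Unfused reference: the float mask is simply added to the scaled scores.
scores = (q * scale) @ k.transpose(0, 1, 3, 2) + fmask
ref = mx.softmax(scores, axis=-1) @ v

print(mx.allclose(out, ref, atol=1e-4, rtol=1e-4))
```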