Add float mask to sdpa vector (#2068)

Author: Angelos Katharopoulos (2025-04-11 17:29:40 -07:00), committed by GitHub
parent 68d1b3256b
commit c4189a38e4
5 changed files with 94 additions and 50 deletions
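
In short: the single-token (vector) SDPA kernels gain support for an additive float mask alongside the existing boolean mask. A minimal usage sketch, assuming illustrative shapes (0.0 keeps a position, -inf drops it):

import mlx.core as mx

B, H, L, D = 1, 8, 256, 64
q = mx.random.normal(shape=(B, H, 1, D))  # single query token, the vector path
k = mx.random.normal(shape=(B, H, L, D))
v = mx.random.normal(shape=(B, H, L, D))

# Additive float mask: 0.0 keeps a position, -inf removes it from the softmax.
keep = mx.random.uniform(shape=(B, H, 1, L)) > 0.2
fmask = mx.where(keep, 0.0, float("-inf"))

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5, mask=fmask)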

mlx/backend/metal/kernels/sdpa_vector.metal

@@ -1,11 +1,11 @@
 #include <metal_stdlib>
-#include "mlx/backend/metal/kernels/sdpa_vector.h"
+// clang-format off
 #include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/sdpa_vector.h"
 
 using namespace metal;
 
-// clang-format off
 // SDPA vector instantiations
 #define instantiate_sdpa_vector_aggregation(type, value_dim) \
   instantiate_kernel( \

mlx/backend/metal/kernels/sdpa_vector.h

@@ -7,6 +7,8 @@ using namespace metal;
 constant bool has_mask [[function_constant(20)]];
 constant bool query_transposed [[function_constant(21)]];
 constant bool do_causal [[function_constant(22)]];
+constant bool bool_mask [[function_constant(23)]];
+constant bool float_mask [[function_constant(24)]];
 
 template <typename T, int D, int V = D>
 [[kernel]] void sdpa_vector(
@@ -14,17 +16,21 @@ template <typename T, int D, int V = D>
     const device T* keys [[buffer(1)]],
     const device T* values [[buffer(2)]],
     device T* out [[buffer(3)]],
-    const constant int& gqa_factor,
-    const constant int& N,
-    const constant size_t& k_head_stride,
-    const constant size_t& k_seq_stride,
-    const constant size_t& v_head_stride,
-    const constant size_t& v_seq_stride,
-    const constant float& scale,
-    const device bool* mask [[function_constant(has_mask)]],
-    const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
-    const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
-    const constant int& mask_head_stride [[function_constant(has_mask)]],
+    const constant int& gqa_factor [[buffer(4)]],
+    const constant int& N [[buffer(5)]],
+    const constant size_t& k_head_stride [[buffer(6)]],
+    const constant size_t& k_seq_stride [[buffer(7)]],
+    const constant size_t& v_head_stride [[buffer(8)]],
+    const constant size_t& v_seq_stride [[buffer(9)]],
+    const constant float& scale [[buffer(10)]],
+    const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
+    const device T* fmask [[buffer(12), function_constant(float_mask)]],
+    const constant int& mask_kv_seq_stride
+        [[buffer(13), function_constant(has_mask)]],
+    const constant int& mask_q_seq_stride
+        [[buffer(14), function_constant(has_mask)]],
+    const constant int& mask_head_stride
+        [[buffer(15), function_constant(has_mask)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -58,8 +64,12 @@ template <typename T, int D, int V = D>
       simd_lid * qk_per_thread;
   values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
       simd_lid * v_per_thread;
-  if (has_mask) {
-    mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
+  if (bool_mask) {
+    bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
+        q_seq_idx * mask_q_seq_stride;
+  }
+  if (float_mask) {
+    fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
         q_seq_idx * mask_q_seq_stride;
   }
 
@@ -81,8 +91,8 @@ template <typename T, int D, int V = D>
     bool use_key = true;
     if (do_causal) {
       use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
-    } else if (has_mask) {
-      use_key = mask[0];
+    } else if (bool_mask) {
+      use_key = bmask[0];
     }
     if (use_key) {
       // Read the key
@@ -96,6 +106,9 @@ template <typename T, int D, int V = D>
         score += q[j] * k[j];
       }
       score = simd_sum(score);
+      if (float_mask) {
+        score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
+      }
 
       // Update the accumulators
       U new_max = max(max_score, score);
@@ -114,8 +127,11 @@ template <typename T, int D, int V = D>
     // Move the pointers to the next kv
     keys += inner_k_stride;
     values += inner_v_stride;
-    if (has_mask) {
-      mask += BN * mask_kv_seq_stride;
+    if (bool_mask) {
+      bmask += BN * mask_kv_seq_stride;
+    }
+    if (float_mask) {
+      fmask += BN * mask_kv_seq_stride;
     }
   }
 
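
For reference, a plain-Python sketch of the streaming-softmax loop this one-pass kernel implements, with the additive mask folded in (all names here are mine, not the kernel's). The clamp mirrors max(Limits<U>::finite_min, fmask[0]) above: it keeps an -inf mask entry from dragging max_score to -inf and turning exp(score - new_max) into NaN.

import math

def streaming_sdpa_1pass(q, k, v, scale, fmask=None, finite_min=-1e30):
    # q: list[float] (one query); k, v: list[list[float]]; fmask: list[float].
    max_score, sum_exp = finite_min, 0.0
    out = [0.0] * len(v[0])
    for i in range(len(k)):
        score = scale * sum(qj * kj for qj, kj in zip(q, k[i]))
        if fmask is not None:
            # Clamp, as in max(Limits<U>::finite_min, fmask[0]) in the kernel.
            score += max(finite_min, fmask[i])
        new_max = max(max_score, score)
        corr = math.exp(max_score - new_max)  # rescale the old accumulators
        w = math.exp(score - new_max)
        out = [o * corr + w * vij for o, vij in zip(out, v[i])]
        sum_exp = sum_exp * corr + w
        max_score = new_max
    return [o / sum_exp for o in out]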
@@ -156,17 +172,21 @@ template <typename T, int D, int V = D>
     device float* out [[buffer(3)]],
     device float* sums [[buffer(4)]],
     device float* maxs [[buffer(5)]],
-    const constant int& gqa_factor,
-    const constant int& N,
-    const constant size_t& k_head_stride,
-    const constant size_t& k_seq_stride,
-    const constant size_t& v_head_stride,
-    const constant size_t& v_seq_stride,
-    const constant float& scale,
-    const device bool* mask [[function_constant(has_mask)]],
-    const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
-    const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
-    const constant int& mask_head_stride [[function_constant(has_mask)]],
+    const constant int& gqa_factor [[buffer(6)]],
+    const constant int& N [[buffer(7)]],
+    const constant size_t& k_head_stride [[buffer(8)]],
+    const constant size_t& k_seq_stride [[buffer(9)]],
+    const constant size_t& v_head_stride [[buffer(10)]],
+    const constant size_t& v_seq_stride [[buffer(11)]],
+    const constant float& scale [[buffer(12)]],
+    const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
+    const device T* fmask [[buffer(14), function_constant(float_mask)]],
+    const constant int& mask_kv_seq_stride
+        [[buffer(15), function_constant(has_mask)]],
+    const constant int& mask_q_seq_stride
+        [[buffer(16), function_constant(has_mask)]],
+    const constant int& mask_head_stride
+        [[buffer(17), function_constant(has_mask)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -204,8 +224,13 @@ template <typename T, int D, int V = D>
   values += kv_head_idx * v_head_stride +
       (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
   out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
-  if (has_mask) {
-    mask += head_idx * mask_head_stride +
+  if (bool_mask) {
+    bmask += head_idx * mask_head_stride +
+        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
+        q_seq_idx * mask_q_seq_stride;
+  }
+  if (float_mask) {
+    fmask += head_idx * mask_head_stride +
         (block_idx * BN + simd_gid) * mask_kv_seq_stride +
         q_seq_idx * mask_q_seq_stride;
   }
@@ -228,8 +253,8 @@ template <typename T, int D, int V = D>
     bool use_key = true;
     if (do_causal) {
       use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
-    } else if (has_mask) {
-      use_key = mask[0];
+    } else if (bool_mask) {
+      use_key = bmask[0];
     }
     if (use_key) {
       // Read the key
@@ -243,6 +268,9 @@ template <typename T, int D, int V = D>
         score += q[i] * k[i];
       }
       score = simd_sum(score);
+      if (float_mask) {
+        score += fmask[0];
+      }
 
       // Update the accumulators
       U new_max = max(max_score, score);
@@ -261,8 +289,11 @@ template <typename T, int D, int V = D>
     // Move the pointers to the next kv
     keys += blocks * inner_k_stride;
     values += blocks * inner_v_stride;
-    if (has_mask) {
-      mask += BN * blocks * mask_kv_seq_stride;
+    if (bool_mask) {
+      bmask += BN * blocks * mask_kv_seq_stride;
+    }
+    if (float_mask) {
+      fmask += BN * blocks * mask_kv_seq_stride;
     }
   }
 
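
This 2-pass variant only produces per-block partials (out, sums, maxs); note that here the mask is applied as a plain score += fmask[0], without the finite_min clamp used in the one-pass kernel. For intuition, a hedged sketch of how such partials can be merged; this is generic log-sum-exp bookkeeping under the stated assumption, not a transcription of MLX's pass-2 kernel:

import math

def merge_blocks(blocks):
    # blocks: list of (acc, sum_exp, max_score), assuming per block b:
    #   acc[j]  = sum_i exp(s_i - max_b) * v[i][j]
    #   sum_exp = sum_i exp(s_i - max_b)
    M = max(m for _, _, m in blocks)
    denom = sum(s * math.exp(m - M) for _, s, m in blocks)
    out = [0.0] * len(blocks[0][0])
    for acc, _, m in blocks:
        w = math.exp(m - M)  # re-reference each block to the global max
        out = [o + a * w for o, a in zip(out, acc)]
    return [o / denom for o in out]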

mlx/backend/metal/scaled_dot_product_attention.cpp

@@ -163,14 +163,18 @@ void sdpa_vector(
   MTL::Size grid_dims(B, q.shape(2), 1);
 
   bool has_mask = mask.has_value();
+  bool bool_mask = has_mask && (*mask).dtype() == bool_;
+  bool float_mask = has_mask && !bool_mask;
   bool query_transposed = !q.flags().row_contiguous;
   metal::MTLFCList func_consts = {
       {&has_mask, MTL::DataType::DataTypeBool, 20},
       {&query_transposed, MTL::DataType::DataTypeBool, 21},
       {&do_causal, MTL::DataType::DataTypeBool, 22},
+      {&bool_mask, MTL::DataType::DataTypeBool, 23},
+      {&float_mask, MTL::DataType::DataTypeBool, 24},
   };
   std::string hash_name = kname;
-  hash_name += has_mask ? "_mask" : "_nomask";
+  hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
   hash_name += query_transposed ? "_qt" : "_qnt";
   hash_name += do_causal ? "_c" : "_nc";
@@ -194,15 +198,15 @@ void sdpa_vector(
   compute_encoder.set_bytes(scale, 10);
   if (has_mask) {
     auto& m = *mask;
-    compute_encoder.set_input_array(m, 11);
+    compute_encoder.set_input_array(m, 11 + float_mask);
     auto nd = m.ndim();
     int32_t kv_seq_stride =
         nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
     int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
     int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
-    compute_encoder.set_bytes(kv_seq_stride, 12);
-    compute_encoder.set_bytes(q_seq_stride, 13);
-    compute_encoder.set_bytes(head_stride, 14);
+    compute_encoder.set_bytes(kv_seq_stride, 13);
+    compute_encoder.set_bytes(q_seq_stride, 14);
+    compute_encoder.set_bytes(head_stride, 15);
   }
 
   // Launch
@@ -260,14 +264,18 @@ void sdpa_vector_2pass(
   d.add_temporary(maxs, s.index);
 
   bool has_mask = mask.has_value();
+  bool bool_mask = has_mask && (*mask).dtype() == bool_;
+  bool float_mask = has_mask && !bool_mask;
   bool query_transposed = !q.flags().row_contiguous;
   metal::MTLFCList func_consts = {
       {&has_mask, MTL::DataType::DataTypeBool, 20},
       {&query_transposed, MTL::DataType::DataTypeBool, 21},
       {&do_causal, MTL::DataType::DataTypeBool, 22},
+      {&bool_mask, MTL::DataType::DataTypeBool, 23},
+      {&float_mask, MTL::DataType::DataTypeBool, 24},
   };
   std::string hash_name = kname;
-  hash_name += has_mask ? "_mask" : "_nomask";
+  hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
   hash_name += query_transposed ? "_qt" : "_qnt";
   hash_name += do_causal ? "_c" : "_nc";
@@ -293,15 +301,15 @@ void sdpa_vector_2pass(
   compute_encoder.set_bytes(scale, 12);
   if (has_mask) {
     auto& m = *mask;
-    compute_encoder.set_input_array(m, 13);
+    compute_encoder.set_input_array(m, 13 + float_mask);
     auto nd = m.ndim();
     int32_t kv_seq_stride =
         nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
     int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
     int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
-    compute_encoder.set_bytes(kv_seq_stride, 14);
-    compute_encoder.set_bytes(q_seq_stride, 15);
-    compute_encoder.set_bytes(head_stride, 16);
+    compute_encoder.set_bytes(kv_seq_stride, 15);
+    compute_encoder.set_bytes(q_seq_stride, 16);
+    compute_encoder.set_bytes(head_stride, 17);
   }
 
   // Launch

mlx/fast.cpp

@@ -739,8 +739,6 @@ array scaled_dot_product_attention(
   const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
 
-  const bool sdpa_vector_supported_mask =
-      !has_mask || has_bool_mask || do_causal;
   const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
       (query_sequence_length <= key_sequence_length && do_causal);
@@ -749,8 +747,7 @@ array scaled_dot_product_attention(
   const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
       (query_sequence_length <= key_sequence_length) &&
-      sdpa_vector_supported_mask && sdpa_vector_supported_head_dim &&
-      stream.device == Device::gpu;
+      sdpa_vector_supported_head_dim && stream.device == Device::gpu;
   const bool implementation_supports_use_case =
       supports_sdpa_full || supports_sdpa_vector;

python/tests/test_fast_sdpa.py

@@ -352,6 +352,10 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
             mx.array([True] * (L - 10) + [False] * 10),
             mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
             mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
+            mx.random.uniform(shape=(Nq, 1, L)),
+            mx.random.uniform(shape=(L, 1, Nq)).T,
+            mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),
+            mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),
             "causal",
         ]
         for m in masks:
@@ -377,6 +381,10 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
             mx.array([True] * (L - 10) + [False] * 10),
             mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
             mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
+            mx.random.uniform(shape=(Nq, 1, L)),
+            mx.random.uniform(shape=(L, 1, Nq)).T,
+            mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),
+            mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),
             "causal",
         ]
         for m in masks:
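
The new test masks cover both arbitrary float scores and masks containing exact -inf via mx.log on a bool array (log(True) is 0.0, log(False) is -inf), so such a float mask should reproduce its boolean counterpart. A quick standalone check along those lines, with illustrative shapes:

import mlx.core as mx

q = mx.random.normal(shape=(1, 8, 1, 64))
k = mx.random.normal(shape=(1, 8, 128, 64))
v = mx.random.normal(shape=(1, 8, 128, 64))

bmask = mx.random.uniform(shape=(1, 1, 1, 128)) > 0.2
fmask = mx.log(bmask)  # True -> 0.0, False -> -inf

out_b = mx.fast.scaled_dot_product_attention(q, k, v, scale=0.125, mask=bmask)
out_f = mx.fast.scaled_dot_product_attention(q, k, v, scale=0.125, mask=fmask)
assert mx.allclose(out_b, out_f).item()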