Add float mask to sdpa vector (#2068)

This commit is contained in:
Angelos Katharopoulos
2025-04-11 17:29:40 -07:00
committed by GitHub
parent 68d1b3256b
commit c4189a38e4
5 changed files with 94 additions and 50 deletions

View File

@@ -163,14 +163,18 @@ void sdpa_vector(
MTL::Size grid_dims(B, q.shape(2), 1);
bool has_mask = mask.has_value();
bool bool_mask = has_mask && (*mask).dtype() == bool_;
bool float_mask = has_mask && !bool_mask;
bool query_transposed = !q.flags().row_contiguous;
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
{&bool_mask, MTL::DataType::DataTypeBool, 23},
{&float_mask, MTL::DataType::DataTypeBool, 24},
};
std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask";
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
@@ -194,15 +198,15 @@ void sdpa_vector(
compute_encoder.set_bytes(scale, 10);
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 11);
compute_encoder.set_input_array(m, 11 + float_mask);
auto nd = m.ndim();
int32_t kv_seq_stride =
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
compute_encoder.set_bytes(kv_seq_stride, 12);
compute_encoder.set_bytes(q_seq_stride, 13);
compute_encoder.set_bytes(head_stride, 14);
compute_encoder.set_bytes(kv_seq_stride, 13);
compute_encoder.set_bytes(q_seq_stride, 14);
compute_encoder.set_bytes(head_stride, 15);
}
// Launch
@@ -260,14 +264,18 @@ void sdpa_vector_2pass(
d.add_temporary(maxs, s.index);
bool has_mask = mask.has_value();
bool bool_mask = has_mask && (*mask).dtype() == bool_;
bool float_mask = has_mask && !bool_mask;
bool query_transposed = !q.flags().row_contiguous;
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
{&bool_mask, MTL::DataType::DataTypeBool, 23},
{&float_mask, MTL::DataType::DataTypeBool, 24},
};
std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask";
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
hash_name += query_transposed ? "_qt" : "_qnt";
hash_name += do_causal ? "_c" : "_nc";
@@ -293,15 +301,15 @@ void sdpa_vector_2pass(
compute_encoder.set_bytes(scale, 12);
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 13);
compute_encoder.set_input_array(m, 13 + float_mask);
auto nd = m.ndim();
int32_t kv_seq_stride =
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
compute_encoder.set_bytes(kv_seq_stride, 14);
compute_encoder.set_bytes(q_seq_stride, 15);
compute_encoder.set_bytes(head_stride, 16);
compute_encoder.set_bytes(kv_seq_stride, 15);
compute_encoder.set_bytes(q_seq_stride, 16);
compute_encoder.set_bytes(head_stride, 17);
}
// Launch