mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add float mask to sdpa vector (#2068)
This commit is contained in:
committed by
GitHub
parent
68d1b3256b
commit
c4189a38e4
@@ -163,14 +163,18 @@ void sdpa_vector(
|
||||
MTL::Size grid_dims(B, q.shape(2), 1);
|
||||
|
||||
bool has_mask = mask.has_value();
|
||||
bool bool_mask = has_mask && (*mask).dtype() == bool_;
|
||||
bool float_mask = has_mask && !bool_mask;
|
||||
bool query_transposed = !q.flags().row_contiguous;
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_mask, MTL::DataType::DataTypeBool, 20},
|
||||
{&query_transposed, MTL::DataType::DataTypeBool, 21},
|
||||
{&do_causal, MTL::DataType::DataTypeBool, 22},
|
||||
{&bool_mask, MTL::DataType::DataTypeBool, 23},
|
||||
{&float_mask, MTL::DataType::DataTypeBool, 24},
|
||||
};
|
||||
std::string hash_name = kname;
|
||||
hash_name += has_mask ? "_mask" : "_nomask";
|
||||
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
|
||||
hash_name += query_transposed ? "_qt" : "_qnt";
|
||||
hash_name += do_causal ? "_c" : "_nc";
|
||||
|
||||
@@ -194,15 +198,15 @@ void sdpa_vector(
|
||||
compute_encoder.set_bytes(scale, 10);
|
||||
if (has_mask) {
|
||||
auto& m = *mask;
|
||||
compute_encoder.set_input_array(m, 11);
|
||||
compute_encoder.set_input_array(m, 11 + float_mask);
|
||||
auto nd = m.ndim();
|
||||
int32_t kv_seq_stride =
|
||||
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
||||
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
||||
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
||||
compute_encoder.set_bytes(kv_seq_stride, 12);
|
||||
compute_encoder.set_bytes(q_seq_stride, 13);
|
||||
compute_encoder.set_bytes(head_stride, 14);
|
||||
compute_encoder.set_bytes(kv_seq_stride, 13);
|
||||
compute_encoder.set_bytes(q_seq_stride, 14);
|
||||
compute_encoder.set_bytes(head_stride, 15);
|
||||
}
|
||||
|
||||
// Launch
|
||||
@@ -260,14 +264,18 @@ void sdpa_vector_2pass(
|
||||
d.add_temporary(maxs, s.index);
|
||||
|
||||
bool has_mask = mask.has_value();
|
||||
bool bool_mask = has_mask && (*mask).dtype() == bool_;
|
||||
bool float_mask = has_mask && !bool_mask;
|
||||
bool query_transposed = !q.flags().row_contiguous;
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_mask, MTL::DataType::DataTypeBool, 20},
|
||||
{&query_transposed, MTL::DataType::DataTypeBool, 21},
|
||||
{&do_causal, MTL::DataType::DataTypeBool, 22},
|
||||
{&bool_mask, MTL::DataType::DataTypeBool, 23},
|
||||
{&float_mask, MTL::DataType::DataTypeBool, 24},
|
||||
};
|
||||
std::string hash_name = kname;
|
||||
hash_name += has_mask ? "_mask" : "_nomask";
|
||||
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
|
||||
hash_name += query_transposed ? "_qt" : "_qnt";
|
||||
hash_name += do_causal ? "_c" : "_nc";
|
||||
|
||||
@@ -293,15 +301,15 @@ void sdpa_vector_2pass(
|
||||
compute_encoder.set_bytes(scale, 12);
|
||||
if (has_mask) {
|
||||
auto& m = *mask;
|
||||
compute_encoder.set_input_array(m, 13);
|
||||
compute_encoder.set_input_array(m, 13 + float_mask);
|
||||
auto nd = m.ndim();
|
||||
int32_t kv_seq_stride =
|
||||
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
||||
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
||||
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
||||
compute_encoder.set_bytes(kv_seq_stride, 14);
|
||||
compute_encoder.set_bytes(q_seq_stride, 15);
|
||||
compute_encoder.set_bytes(head_stride, 16);
|
||||
compute_encoder.set_bytes(kv_seq_stride, 15);
|
||||
compute_encoder.set_bytes(q_seq_stride, 16);
|
||||
compute_encoder.set_bytes(head_stride, 17);
|
||||
}
|
||||
|
||||
// Launch
|
||||
|
||||
Reference in New Issue
Block a user