mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-24 09:21:16 +08:00
Add float mask to sdpa vector (#2068)
parent: 68d1b3256b
commit: c4189a38e4
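In short: the single-pass and two-pass SDPA vector kernels gain an additive floating-point mask alongside the existing boolean mask, selected at pipeline-build time via Metal function constants. A minimal sketch of exercising both paths from Python (shapes and the tolerance are illustrative, not from the commit):

import mlx.core as mx

B, H, L, D = 1, 8, 32, 64
q = mx.random.normal(shape=(B, H, 1, D))  # q_seq_len <= 8 routes to the vector kernel
k = mx.random.normal(shape=(B, H, L, D))
v = mx.random.normal(shape=(B, H, L, D))
scale = D**-0.5

bool_mask = mx.random.uniform(shape=(B, H, 1, L)) > 0.2
float_mask = mx.log(bool_mask)  # True -> 0.0, False -> -inf, as in the tests below

out_bool = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=bool_mask)
out_float = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=float_mask)
assert mx.allclose(out_bool, out_float, atol=1e-4)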
mlx/backend/metal/kernels/sdpa_vector.metal

@@ -1,11 +1,11 @@
-#include <metal_stdlib>
-
-#include "mlx/backend/metal/kernels/sdpa_vector.h"
+// clang-format off
 #include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/sdpa_vector.h"
 
 using namespace metal;
 
-// clang-format off
 // SDPA vector instantiations
 #define instantiate_sdpa_vector_aggregation(type, value_dim) \
   instantiate_kernel( \
mlx/backend/metal/kernels/sdpa_vector.h

@@ -7,6 +7,8 @@ using namespace metal;
 constant bool has_mask [[function_constant(20)]];
 constant bool query_transposed [[function_constant(21)]];
 constant bool do_causal [[function_constant(22)]];
+constant bool bool_mask [[function_constant(23)]];
+constant bool float_mask [[function_constant(24)]];
 
 template <typename T, int D, int V = D>
 [[kernel]] void sdpa_vector(
@@ -14,17 +16,21 @@ template <typename T, int D, int V = D>
     const device T* keys [[buffer(1)]],
     const device T* values [[buffer(2)]],
     device T* out [[buffer(3)]],
-    const constant int& gqa_factor,
-    const constant int& N,
-    const constant size_t& k_head_stride,
-    const constant size_t& k_seq_stride,
-    const constant size_t& v_head_stride,
-    const constant size_t& v_seq_stride,
-    const constant float& scale,
-    const device bool* mask [[function_constant(has_mask)]],
-    const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
-    const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
-    const constant int& mask_head_stride [[function_constant(has_mask)]],
+    const constant int& gqa_factor [[buffer(4)]],
+    const constant int& N [[buffer(5)]],
+    const constant size_t& k_head_stride [[buffer(6)]],
+    const constant size_t& k_seq_stride [[buffer(7)]],
+    const constant size_t& v_head_stride [[buffer(8)]],
+    const constant size_t& v_seq_stride [[buffer(9)]],
+    const constant float& scale [[buffer(10)]],
+    const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
+    const device T* fmask [[buffer(12), function_constant(float_mask)]],
+    const constant int& mask_kv_seq_stride
+        [[buffer(13), function_constant(has_mask)]],
+    const constant int& mask_q_seq_stride
+        [[buffer(14), function_constant(has_mask)]],
+    const constant int& mask_head_stride
+        [[buffer(15), function_constant(has_mask)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
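The single boolean mask pointer becomes two dtype-specific pointers, bmask at buffer 11 and fmask at buffer 12, each compiled into the pipeline only when its function constant is set, while the stride arguments (buffers 13-15) stay gated on has_mask. A rough Python picture of which optional arguments end up live for a given mask dtype (a sketch for illustration, not MLX API):

def live_mask_args(mask_dtype):
    # mask_dtype: None, "bool", or a float dtype name such as "float16"
    has_mask = mask_dtype is not None
    bool_mask = has_mask and mask_dtype == "bool"
    float_mask = has_mask and not bool_mask
    args = []
    if bool_mask:
        args.append("bmask [[buffer(11)]]")
    if float_mask:
        args.append("fmask [[buffer(12)]]")
    if has_mask:
        args += [
            "mask_kv_seq_stride [[buffer(13)]]",
            "mask_q_seq_stride [[buffer(14)]]",
            "mask_head_stride [[buffer(15)]]",
        ]
    return args

print(live_mask_args("float16"))
# ['fmask [[buffer(12)]]', 'mask_kv_seq_stride [[buffer(13)]]', ...]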
@@ -58,8 +64,12 @@ template <typename T, int D, int V = D>
       simd_lid * qk_per_thread;
   values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
       simd_lid * v_per_thread;
-  if (has_mask) {
-    mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
+  if (bool_mask) {
+    bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
+        q_seq_idx * mask_q_seq_stride;
+  }
+  if (float_mask) {
+    fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
         q_seq_idx * mask_q_seq_stride;
   }
 
@@ -81,8 +91,8 @@ template <typename T, int D, int V = D>
     bool use_key = true;
     if (do_causal) {
       use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
-    } else if (has_mask) {
-      use_key = mask[0];
+    } else if (bool_mask) {
+      use_key = bmask[0];
     }
     if (use_key) {
       // Read the key
@@ -96,6 +106,9 @@ template <typename T, int D, int V = D>
         score += q[j] * k[j];
       }
       score = simd_sum(score);
+      if (float_mask) {
+        score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
+      }
 
       // Update the accumulators
       U new_max = max(max_score, score);
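The additive mask is clamped to Limits<U>::finite_min before it touches the score. A plausible reading of why (a sketch of the rationale, not code from the commit): a raw -inf mask entry makes the score itself -inf, and the online-softmax update exp(score - max_score) can then compute -inf - (-inf), which is NaN; clamping keeps every score finite so the update stays well defined:

import math

scores = [float("-inf"), float("-inf")]    # fully masked row, no clamping
m = max(scores)
print([math.exp(s - m) for s in scores])   # [nan, nan]: -inf minus -inf is nan

FINITE_MIN = -3.4e38                       # stand-in for Limits<float>::finite_min
clamped = [max(FINITE_MIN, s) for s in scores]
m = max(clamped)
print([math.exp(s - m) for s in clamped])  # [1.0, 1.0]: update stays defined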
@@ -114,8 +127,11 @@ template <typename T, int D, int V = D>
     // Move the pointers to the next kv
     keys += inner_k_stride;
     values += inner_v_stride;
-    if (has_mask) {
-      mask += BN * mask_kv_seq_stride;
+    if (bool_mask) {
+      bmask += BN * mask_kv_seq_stride;
+    }
+    if (float_mask) {
+      fmask += BN * mask_kv_seq_stride;
     }
   }
 
@@ -156,17 +172,21 @@ template <typename T, int D, int V = D>
     device float* out [[buffer(3)]],
     device float* sums [[buffer(4)]],
     device float* maxs [[buffer(5)]],
-    const constant int& gqa_factor,
-    const constant int& N,
-    const constant size_t& k_head_stride,
-    const constant size_t& k_seq_stride,
-    const constant size_t& v_head_stride,
-    const constant size_t& v_seq_stride,
-    const constant float& scale,
-    const device bool* mask [[function_constant(has_mask)]],
-    const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
-    const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
-    const constant int& mask_head_stride [[function_constant(has_mask)]],
+    const constant int& gqa_factor [[buffer(6)]],
+    const constant int& N [[buffer(7)]],
+    const constant size_t& k_head_stride [[buffer(8)]],
+    const constant size_t& k_seq_stride [[buffer(9)]],
+    const constant size_t& v_head_stride [[buffer(10)]],
+    const constant size_t& v_seq_stride [[buffer(11)]],
+    const constant float& scale [[buffer(12)]],
+    const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
+    const device T* fmask [[buffer(14), function_constant(float_mask)]],
+    const constant int& mask_kv_seq_stride
+        [[buffer(15), function_constant(has_mask)]],
+    const constant int& mask_q_seq_stride
+        [[buffer(16), function_constant(has_mask)]],
+    const constant int& mask_head_stride
+        [[buffer(17), function_constant(has_mask)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -204,8 +224,13 @@ template <typename T, int D, int V = D>
   values += kv_head_idx * v_head_stride +
       (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
   out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
-  if (has_mask) {
-    mask += head_idx * mask_head_stride +
+  if (bool_mask) {
+    bmask += head_idx * mask_head_stride +
+        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
+        q_seq_idx * mask_q_seq_stride;
+  }
+  if (float_mask) {
+    fmask += head_idx * mask_head_stride +
         (block_idx * BN + simd_gid) * mask_kv_seq_stride +
         q_seq_idx * mask_q_seq_stride;
   }
@@ -228,8 +253,8 @@ template <typename T, int D, int V = D>
     bool use_key = true;
     if (do_causal) {
       use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
-    } else if (has_mask) {
-      use_key = mask[0];
+    } else if (bool_mask) {
+      use_key = bmask[0];
     }
     if (use_key) {
       // Read the key
@@ -243,6 +268,9 @@ template <typename T, int D, int V = D>
         score += q[i] * k[i];
       }
       score = simd_sum(score);
+      if (float_mask) {
+        score += fmask[0];
+      }
 
       // Update the accumulators
       U new_max = max(max_score, score);
@@ -261,8 +289,11 @@ template <typename T, int D, int V = D>
     // Move the pointers to the next kv
     keys += blocks * inner_k_stride;
     values += blocks * inner_v_stride;
-    if (has_mask) {
-      mask += BN * blocks * mask_kv_seq_stride;
+    if (bool_mask) {
+      bmask += BN * blocks * mask_kv_seq_stride;
+    }
+    if (float_mask) {
+      fmask += BN * blocks * mask_kv_seq_stride;
     }
   }
mlx/backend/metal/scaled_dot_product_attention.cpp

@@ -163,14 +163,18 @@ void sdpa_vector(
   MTL::Size grid_dims(B, q.shape(2), 1);
 
   bool has_mask = mask.has_value();
+  bool bool_mask = has_mask && (*mask).dtype() == bool_;
+  bool float_mask = has_mask && !bool_mask;
   bool query_transposed = !q.flags().row_contiguous;
   metal::MTLFCList func_consts = {
       {&has_mask, MTL::DataType::DataTypeBool, 20},
       {&query_transposed, MTL::DataType::DataTypeBool, 21},
       {&do_causal, MTL::DataType::DataTypeBool, 22},
+      {&bool_mask, MTL::DataType::DataTypeBool, 23},
+      {&float_mask, MTL::DataType::DataTypeBool, 24},
   };
   std::string hash_name = kname;
-  hash_name += has_mask ? "_mask" : "_nomask";
+  hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
   hash_name += query_transposed ? "_qt" : "_qnt";
   hash_name += do_causal ? "_c" : "_nc";
 
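Host side, the mask dtype now picks the kernel specialization: two new function constants (23, 24) and a dtype-specific suffix in the pipeline cache key. A Python restatement of the naming logic above (a sketch; the dtype handling is simplified):

def kernel_variant(kname, mask_dtype, query_transposed, do_causal):
    has_mask = mask_dtype is not None
    bool_mask = has_mask and mask_dtype == "bool"
    name = kname
    name += ("_boolmask" if bool_mask else "_floatmask") if has_mask else "_nomask"
    name += "_qt" if query_transposed else "_qnt"
    name += "_c" if do_causal else "_nc"
    return name

assert kernel_variant("sdpa_vector", "float16", False, False) == "sdpa_vector_floatmask_qnt_nc"
assert kernel_variant("sdpa_vector", None, False, True) == "sdpa_vector_nomask_qnt_c"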
@@ -194,15 +198,15 @@ void sdpa_vector(
   compute_encoder.set_bytes(scale, 10);
   if (has_mask) {
     auto& m = *mask;
-    compute_encoder.set_input_array(m, 11);
+    compute_encoder.set_input_array(m, 11 + float_mask);
     auto nd = m.ndim();
     int32_t kv_seq_stride =
         nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
     int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
     int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
-    compute_encoder.set_bytes(kv_seq_stride, 12);
-    compute_encoder.set_bytes(q_seq_stride, 13);
-    compute_encoder.set_bytes(head_stride, 14);
+    compute_encoder.set_bytes(kv_seq_stride, 13);
+    compute_encoder.set_bytes(q_seq_stride, 14);
+    compute_encoder.set_bytes(head_stride, 15);
   }
 
   // Launch
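Since bmask and fmask occupy adjacent buffer slots, the one mask array that actually exists is bound at 11 + float_mask (11 for bool, 12 for float), and the stride constants shift from 12-14 to 13-15 to make room. In miniature:

def mask_binding(base, is_float_mask):
    # bool mask -> buffer `base`, float mask -> buffer `base + 1`
    return base + int(is_float_mask)

assert mask_binding(11, False) == 11  # bmask slot in sdpa_vector
assert mask_binding(11, True) == 12   # fmask slot in sdpa_vector
assert mask_binding(13, True) == 14   # fmask slot in sdpa_vector_2pass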
@@ -260,14 +264,18 @@ void sdpa_vector_2pass(
   d.add_temporary(maxs, s.index);
 
   bool has_mask = mask.has_value();
+  bool bool_mask = has_mask && (*mask).dtype() == bool_;
+  bool float_mask = has_mask && !bool_mask;
   bool query_transposed = !q.flags().row_contiguous;
   metal::MTLFCList func_consts = {
       {&has_mask, MTL::DataType::DataTypeBool, 20},
       {&query_transposed, MTL::DataType::DataTypeBool, 21},
       {&do_causal, MTL::DataType::DataTypeBool, 22},
+      {&bool_mask, MTL::DataType::DataTypeBool, 23},
+      {&float_mask, MTL::DataType::DataTypeBool, 24},
   };
   std::string hash_name = kname;
-  hash_name += has_mask ? "_mask" : "_nomask";
+  hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
   hash_name += query_transposed ? "_qt" : "_qnt";
   hash_name += do_causal ? "_c" : "_nc";
 
@@ -293,15 +301,15 @@ void sdpa_vector_2pass(
   compute_encoder.set_bytes(scale, 12);
   if (has_mask) {
     auto& m = *mask;
-    compute_encoder.set_input_array(m, 13);
+    compute_encoder.set_input_array(m, 13 + float_mask);
     auto nd = m.ndim();
     int32_t kv_seq_stride =
         nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
     int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
     int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
-    compute_encoder.set_bytes(kv_seq_stride, 14);
-    compute_encoder.set_bytes(q_seq_stride, 15);
-    compute_encoder.set_bytes(head_stride, 16);
+    compute_encoder.set_bytes(kv_seq_stride, 15);
+    compute_encoder.set_bytes(q_seq_stride, 16);
+    compute_encoder.set_bytes(head_stride, 17);
   }
 
   // Launch
mlx/fast.cpp

@@ -739,8 +739,6 @@ array scaled_dot_product_attention(
   const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
 
-  const bool sdpa_vector_supported_mask =
-      !has_mask || has_bool_mask || do_causal;
   const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
       (query_sequence_length <= key_sequence_length && do_causal);
@@ -749,8 +747,7 @@ array scaled_dot_product_attention(
 
   const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
       (query_sequence_length <= key_sequence_length) &&
-      sdpa_vector_supported_mask && sdpa_vector_supported_head_dim &&
-      stream.device == Device::gpu;
+      sdpa_vector_supported_head_dim && stream.device == Device::gpu;
 
   const bool implementation_supports_use_case =
       supports_sdpa_full || supports_sdpa_vector;
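With both mask dtypes handled in the kernels, the separate sdpa_vector_supported_mask check disappears and routing to the vector path depends only on sequence lengths, head dim, and device. Roughly (a sketch of the predicate, not the full routine):

def supports_sdpa_vector(q_len, kv_len, head_dim_supported, on_gpu):
    # The mask kind no longer appears here: bool, float, causal, and
    # no-mask inputs all route to the vector kernels when this holds.
    return q_len <= 8 and q_len <= kv_len and head_dim_supported and on_gpu

assert supports_sdpa_vector(1, 4096, True, True)
assert not supports_sdpa_vector(16, 4096, True, True)  # falls back to the full kernel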
python/tests/test_fast_sdpa.py

@@ -352,6 +352,10 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
             mx.array([True] * (L - 10) + [False] * 10),
             mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
             mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
+            mx.random.uniform(shape=(Nq, 1, L)),
+            mx.random.uniform(shape=(L, 1, Nq)).T,
+            mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),
+            mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),
             "causal",
         ]
         for m in masks:
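The new cases cover contiguous and transposed float masks, plus log-of-bool masks that contain -inf, exercising the non-contiguous mask strides too. For reference, the semantics the test compares against: a bool mask selects keys, while a float mask is added to the scaled logits before the softmax. A hedged reference implementation (the helper name and shape assumptions are illustrative):

import mlx.core as mx

def sdpa_reference(q, k, v, scale, mask=None):
    # q, k, v: (B, H, L_q, D), (B, H, L_kv, D), (B, H, L_kv, D)
    scores = (q * scale) @ mx.swapaxes(k, -1, -2)
    if mask is not None:
        if mask.dtype == mx.bool_:
            scores = mx.where(mask, scores, -float("inf"))  # bool: select keys
        else:
            scores = scores + mask  # float: additive mask
    return mx.softmax(scores, axis=-1) @ v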
@@ -377,6 +381,10 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
             mx.array([True] * (L - 10) + [False] * 10),
             mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
             mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
+            mx.random.uniform(shape=(Nq, 1, L)),
+            mx.random.uniform(shape=(L, 1, Nq)).T,
+            mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),
+            mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),
             "causal",
         ]
         for m in masks: