mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Add float mask to sdpa vector (#2068)
This commit is contained in:
parent
68d1b3256b
commit
c4189a38e4
@ -1,11 +1,11 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/sdpa_vector.h"
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
#include "mlx/backend/metal/kernels/sdpa_vector.h"
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
// SDPA vector instantiations
|
// SDPA vector instantiations
|
||||||
#define instantiate_sdpa_vector_aggregation(type, value_dim) \
|
#define instantiate_sdpa_vector_aggregation(type, value_dim) \
|
||||||
instantiate_kernel( \
|
instantiate_kernel( \
|
||||||
|
@ -7,6 +7,8 @@ using namespace metal;
|
|||||||
constant bool has_mask [[function_constant(20)]];
|
constant bool has_mask [[function_constant(20)]];
|
||||||
constant bool query_transposed [[function_constant(21)]];
|
constant bool query_transposed [[function_constant(21)]];
|
||||||
constant bool do_causal [[function_constant(22)]];
|
constant bool do_causal [[function_constant(22)]];
|
||||||
|
constant bool bool_mask [[function_constant(23)]];
|
||||||
|
constant bool float_mask [[function_constant(24)]];
|
||||||
|
|
||||||
template <typename T, int D, int V = D>
|
template <typename T, int D, int V = D>
|
||||||
[[kernel]] void sdpa_vector(
|
[[kernel]] void sdpa_vector(
|
||||||
@ -14,17 +16,21 @@ template <typename T, int D, int V = D>
|
|||||||
const device T* keys [[buffer(1)]],
|
const device T* keys [[buffer(1)]],
|
||||||
const device T* values [[buffer(2)]],
|
const device T* values [[buffer(2)]],
|
||||||
device T* out [[buffer(3)]],
|
device T* out [[buffer(3)]],
|
||||||
const constant int& gqa_factor,
|
const constant int& gqa_factor [[buffer(4)]],
|
||||||
const constant int& N,
|
const constant int& N [[buffer(5)]],
|
||||||
const constant size_t& k_head_stride,
|
const constant size_t& k_head_stride [[buffer(6)]],
|
||||||
const constant size_t& k_seq_stride,
|
const constant size_t& k_seq_stride [[buffer(7)]],
|
||||||
const constant size_t& v_head_stride,
|
const constant size_t& v_head_stride [[buffer(8)]],
|
||||||
const constant size_t& v_seq_stride,
|
const constant size_t& v_seq_stride [[buffer(9)]],
|
||||||
const constant float& scale,
|
const constant float& scale [[buffer(10)]],
|
||||||
const device bool* mask [[function_constant(has_mask)]],
|
const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
|
||||||
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
|
const device T* fmask [[buffer(12), function_constant(float_mask)]],
|
||||||
const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
|
const constant int& mask_kv_seq_stride
|
||||||
const constant int& mask_head_stride [[function_constant(has_mask)]],
|
[[buffer(13), function_constant(has_mask)]],
|
||||||
|
const constant int& mask_q_seq_stride
|
||||||
|
[[buffer(14), function_constant(has_mask)]],
|
||||||
|
const constant int& mask_head_stride
|
||||||
|
[[buffer(15), function_constant(has_mask)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 tpg [[threadgroups_per_grid]],
|
uint3 tpg [[threadgroups_per_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
@ -58,8 +64,12 @@ template <typename T, int D, int V = D>
|
|||||||
simd_lid * qk_per_thread;
|
simd_lid * qk_per_thread;
|
||||||
values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
|
values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
|
||||||
simd_lid * v_per_thread;
|
simd_lid * v_per_thread;
|
||||||
if (has_mask) {
|
if (bool_mask) {
|
||||||
mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
|
bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
|
||||||
|
q_seq_idx * mask_q_seq_stride;
|
||||||
|
}
|
||||||
|
if (float_mask) {
|
||||||
|
fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
|
||||||
q_seq_idx * mask_q_seq_stride;
|
q_seq_idx * mask_q_seq_stride;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,8 +91,8 @@ template <typename T, int D, int V = D>
|
|||||||
bool use_key = true;
|
bool use_key = true;
|
||||||
if (do_causal) {
|
if (do_causal) {
|
||||||
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
|
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
|
||||||
} else if (has_mask) {
|
} else if (bool_mask) {
|
||||||
use_key = mask[0];
|
use_key = bmask[0];
|
||||||
}
|
}
|
||||||
if (use_key) {
|
if (use_key) {
|
||||||
// Read the key
|
// Read the key
|
||||||
@ -96,6 +106,9 @@ template <typename T, int D, int V = D>
|
|||||||
score += q[j] * k[j];
|
score += q[j] * k[j];
|
||||||
}
|
}
|
||||||
score = simd_sum(score);
|
score = simd_sum(score);
|
||||||
|
if (float_mask) {
|
||||||
|
score += max(Limits<U>::finite_min, static_cast<U>(fmask[0]));
|
||||||
|
}
|
||||||
|
|
||||||
// Update the accumulators
|
// Update the accumulators
|
||||||
U new_max = max(max_score, score);
|
U new_max = max(max_score, score);
|
||||||
@ -114,8 +127,11 @@ template <typename T, int D, int V = D>
|
|||||||
// Move the pointers to the next kv
|
// Move the pointers to the next kv
|
||||||
keys += inner_k_stride;
|
keys += inner_k_stride;
|
||||||
values += inner_v_stride;
|
values += inner_v_stride;
|
||||||
if (has_mask) {
|
if (bool_mask) {
|
||||||
mask += BN * mask_kv_seq_stride;
|
bmask += BN * mask_kv_seq_stride;
|
||||||
|
}
|
||||||
|
if (float_mask) {
|
||||||
|
fmask += BN * mask_kv_seq_stride;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -156,17 +172,21 @@ template <typename T, int D, int V = D>
|
|||||||
device float* out [[buffer(3)]],
|
device float* out [[buffer(3)]],
|
||||||
device float* sums [[buffer(4)]],
|
device float* sums [[buffer(4)]],
|
||||||
device float* maxs [[buffer(5)]],
|
device float* maxs [[buffer(5)]],
|
||||||
const constant int& gqa_factor,
|
const constant int& gqa_factor [[buffer(6)]],
|
||||||
const constant int& N,
|
const constant int& N [[buffer(7)]],
|
||||||
const constant size_t& k_head_stride,
|
const constant size_t& k_head_stride [[buffer(8)]],
|
||||||
const constant size_t& k_seq_stride,
|
const constant size_t& k_seq_stride [[buffer(9)]],
|
||||||
const constant size_t& v_head_stride,
|
const constant size_t& v_head_stride [[buffer(10)]],
|
||||||
const constant size_t& v_seq_stride,
|
const constant size_t& v_seq_stride [[buffer(11)]],
|
||||||
const constant float& scale,
|
const constant float& scale [[buffer(12)]],
|
||||||
const device bool* mask [[function_constant(has_mask)]],
|
const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
|
||||||
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
|
const device T* fmask [[buffer(14), function_constant(float_mask)]],
|
||||||
const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
|
const constant int& mask_kv_seq_stride
|
||||||
const constant int& mask_head_stride [[function_constant(has_mask)]],
|
[[buffer(15), function_constant(has_mask)]],
|
||||||
|
const constant int& mask_q_seq_stride
|
||||||
|
[[buffer(16), function_constant(has_mask)]],
|
||||||
|
const constant int& mask_head_stride
|
||||||
|
[[buffer(17), function_constant(has_mask)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 tpg [[threadgroups_per_grid]],
|
uint3 tpg [[threadgroups_per_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
@ -204,8 +224,13 @@ template <typename T, int D, int V = D>
|
|||||||
values += kv_head_idx * v_head_stride +
|
values += kv_head_idx * v_head_stride +
|
||||||
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
|
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
|
||||||
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
|
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
|
||||||
if (has_mask) {
|
if (bool_mask) {
|
||||||
mask += head_idx * mask_head_stride +
|
bmask += head_idx * mask_head_stride +
|
||||||
|
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
|
||||||
|
q_seq_idx * mask_q_seq_stride;
|
||||||
|
}
|
||||||
|
if (float_mask) {
|
||||||
|
fmask += head_idx * mask_head_stride +
|
||||||
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
|
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
|
||||||
q_seq_idx * mask_q_seq_stride;
|
q_seq_idx * mask_q_seq_stride;
|
||||||
}
|
}
|
||||||
@ -228,8 +253,8 @@ template <typename T, int D, int V = D>
|
|||||||
bool use_key = true;
|
bool use_key = true;
|
||||||
if (do_causal) {
|
if (do_causal) {
|
||||||
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
|
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
|
||||||
} else if (has_mask) {
|
} else if (bool_mask) {
|
||||||
use_key = mask[0];
|
use_key = bmask[0];
|
||||||
}
|
}
|
||||||
if (use_key) {
|
if (use_key) {
|
||||||
// Read the key
|
// Read the key
|
||||||
@ -243,6 +268,9 @@ template <typename T, int D, int V = D>
|
|||||||
score += q[i] * k[i];
|
score += q[i] * k[i];
|
||||||
}
|
}
|
||||||
score = simd_sum(score);
|
score = simd_sum(score);
|
||||||
|
if (float_mask) {
|
||||||
|
score += fmask[0];
|
||||||
|
}
|
||||||
|
|
||||||
// Update the accumulators
|
// Update the accumulators
|
||||||
U new_max = max(max_score, score);
|
U new_max = max(max_score, score);
|
||||||
@ -261,8 +289,11 @@ template <typename T, int D, int V = D>
|
|||||||
// Move the pointers to the next kv
|
// Move the pointers to the next kv
|
||||||
keys += blocks * inner_k_stride;
|
keys += blocks * inner_k_stride;
|
||||||
values += blocks * inner_v_stride;
|
values += blocks * inner_v_stride;
|
||||||
if (has_mask) {
|
if (bool_mask) {
|
||||||
mask += BN * blocks * mask_kv_seq_stride;
|
bmask += BN * blocks * mask_kv_seq_stride;
|
||||||
|
}
|
||||||
|
if (float_mask) {
|
||||||
|
fmask += BN * blocks * mask_kv_seq_stride;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,14 +163,18 @@ void sdpa_vector(
|
|||||||
MTL::Size grid_dims(B, q.shape(2), 1);
|
MTL::Size grid_dims(B, q.shape(2), 1);
|
||||||
|
|
||||||
bool has_mask = mask.has_value();
|
bool has_mask = mask.has_value();
|
||||||
|
bool bool_mask = has_mask && (*mask).dtype() == bool_;
|
||||||
|
bool float_mask = has_mask && !bool_mask;
|
||||||
bool query_transposed = !q.flags().row_contiguous;
|
bool query_transposed = !q.flags().row_contiguous;
|
||||||
metal::MTLFCList func_consts = {
|
metal::MTLFCList func_consts = {
|
||||||
{&has_mask, MTL::DataType::DataTypeBool, 20},
|
{&has_mask, MTL::DataType::DataTypeBool, 20},
|
||||||
{&query_transposed, MTL::DataType::DataTypeBool, 21},
|
{&query_transposed, MTL::DataType::DataTypeBool, 21},
|
||||||
{&do_causal, MTL::DataType::DataTypeBool, 22},
|
{&do_causal, MTL::DataType::DataTypeBool, 22},
|
||||||
|
{&bool_mask, MTL::DataType::DataTypeBool, 23},
|
||||||
|
{&float_mask, MTL::DataType::DataTypeBool, 24},
|
||||||
};
|
};
|
||||||
std::string hash_name = kname;
|
std::string hash_name = kname;
|
||||||
hash_name += has_mask ? "_mask" : "_nomask";
|
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
|
||||||
hash_name += query_transposed ? "_qt" : "_qnt";
|
hash_name += query_transposed ? "_qt" : "_qnt";
|
||||||
hash_name += do_causal ? "_c" : "_nc";
|
hash_name += do_causal ? "_c" : "_nc";
|
||||||
|
|
||||||
@ -194,15 +198,15 @@ void sdpa_vector(
|
|||||||
compute_encoder.set_bytes(scale, 10);
|
compute_encoder.set_bytes(scale, 10);
|
||||||
if (has_mask) {
|
if (has_mask) {
|
||||||
auto& m = *mask;
|
auto& m = *mask;
|
||||||
compute_encoder.set_input_array(m, 11);
|
compute_encoder.set_input_array(m, 11 + float_mask);
|
||||||
auto nd = m.ndim();
|
auto nd = m.ndim();
|
||||||
int32_t kv_seq_stride =
|
int32_t kv_seq_stride =
|
||||||
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
||||||
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
||||||
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
||||||
compute_encoder.set_bytes(kv_seq_stride, 12);
|
compute_encoder.set_bytes(kv_seq_stride, 13);
|
||||||
compute_encoder.set_bytes(q_seq_stride, 13);
|
compute_encoder.set_bytes(q_seq_stride, 14);
|
||||||
compute_encoder.set_bytes(head_stride, 14);
|
compute_encoder.set_bytes(head_stride, 15);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch
|
// Launch
|
||||||
@ -260,14 +264,18 @@ void sdpa_vector_2pass(
|
|||||||
d.add_temporary(maxs, s.index);
|
d.add_temporary(maxs, s.index);
|
||||||
|
|
||||||
bool has_mask = mask.has_value();
|
bool has_mask = mask.has_value();
|
||||||
|
bool bool_mask = has_mask && (*mask).dtype() == bool_;
|
||||||
|
bool float_mask = has_mask && !bool_mask;
|
||||||
bool query_transposed = !q.flags().row_contiguous;
|
bool query_transposed = !q.flags().row_contiguous;
|
||||||
metal::MTLFCList func_consts = {
|
metal::MTLFCList func_consts = {
|
||||||
{&has_mask, MTL::DataType::DataTypeBool, 20},
|
{&has_mask, MTL::DataType::DataTypeBool, 20},
|
||||||
{&query_transposed, MTL::DataType::DataTypeBool, 21},
|
{&query_transposed, MTL::DataType::DataTypeBool, 21},
|
||||||
{&do_causal, MTL::DataType::DataTypeBool, 22},
|
{&do_causal, MTL::DataType::DataTypeBool, 22},
|
||||||
|
{&bool_mask, MTL::DataType::DataTypeBool, 23},
|
||||||
|
{&float_mask, MTL::DataType::DataTypeBool, 24},
|
||||||
};
|
};
|
||||||
std::string hash_name = kname;
|
std::string hash_name = kname;
|
||||||
hash_name += has_mask ? "_mask" : "_nomask";
|
hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask";
|
||||||
hash_name += query_transposed ? "_qt" : "_qnt";
|
hash_name += query_transposed ? "_qt" : "_qnt";
|
||||||
hash_name += do_causal ? "_c" : "_nc";
|
hash_name += do_causal ? "_c" : "_nc";
|
||||||
|
|
||||||
@ -293,15 +301,15 @@ void sdpa_vector_2pass(
|
|||||||
compute_encoder.set_bytes(scale, 12);
|
compute_encoder.set_bytes(scale, 12);
|
||||||
if (has_mask) {
|
if (has_mask) {
|
||||||
auto& m = *mask;
|
auto& m = *mask;
|
||||||
compute_encoder.set_input_array(m, 13);
|
compute_encoder.set_input_array(m, 13 + float_mask);
|
||||||
auto nd = m.ndim();
|
auto nd = m.ndim();
|
||||||
int32_t kv_seq_stride =
|
int32_t kv_seq_stride =
|
||||||
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
nd >= 1 && m.shape(-1) > 1 ? m.strides()[nd - 1] : 0;
|
||||||
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
int32_t q_seq_stride = nd >= 2 && m.shape(-2) > 1 ? m.strides()[nd - 2] : 0;
|
||||||
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
int32_t head_stride = nd >= 3 && m.shape(-3) > 1 ? m.strides()[nd - 3] : 0;
|
||||||
compute_encoder.set_bytes(kv_seq_stride, 14);
|
compute_encoder.set_bytes(kv_seq_stride, 15);
|
||||||
compute_encoder.set_bytes(q_seq_stride, 15);
|
compute_encoder.set_bytes(q_seq_stride, 16);
|
||||||
compute_encoder.set_bytes(head_stride, 16);
|
compute_encoder.set_bytes(head_stride, 17);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch
|
// Launch
|
||||||
|
@ -739,8 +739,6 @@ array scaled_dot_product_attention(
|
|||||||
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
|
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
|
||||||
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
||||||
|
|
||||||
const bool sdpa_vector_supported_mask =
|
|
||||||
!has_mask || has_bool_mask || do_causal;
|
|
||||||
const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
|
const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
|
||||||
(query_sequence_length <= key_sequence_length && do_causal);
|
(query_sequence_length <= key_sequence_length && do_causal);
|
||||||
|
|
||||||
@ -749,8 +747,7 @@ array scaled_dot_product_attention(
|
|||||||
|
|
||||||
const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
|
const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
|
||||||
(query_sequence_length <= key_sequence_length) &&
|
(query_sequence_length <= key_sequence_length) &&
|
||||||
sdpa_vector_supported_mask && sdpa_vector_supported_head_dim &&
|
sdpa_vector_supported_head_dim && stream.device == Device::gpu;
|
||||||
stream.device == Device::gpu;
|
|
||||||
|
|
||||||
const bool implementation_supports_use_case =
|
const bool implementation_supports_use_case =
|
||||||
supports_sdpa_full || supports_sdpa_vector;
|
supports_sdpa_full || supports_sdpa_vector;
|
||||||
|
@ -352,6 +352,10 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
mx.array([True] * (L - 10) + [False] * 10),
|
mx.array([True] * (L - 10) + [False] * 10),
|
||||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||||
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
|
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
|
||||||
|
mx.random.uniform(shape=(Nq, 1, L)),
|
||||||
|
mx.random.uniform(shape=(L, 1, Nq)).T,
|
||||||
|
mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),
|
||||||
|
mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),
|
||||||
"causal",
|
"causal",
|
||||||
]
|
]
|
||||||
for m in masks:
|
for m in masks:
|
||||||
@ -377,6 +381,10 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
mx.array([True] * (L - 10) + [False] * 10),
|
mx.array([True] * (L - 10) + [False] * 10),
|
||||||
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
|
||||||
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
|
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
|
||||||
|
mx.random.uniform(shape=(Nq, 1, L)),
|
||||||
|
mx.random.uniform(shape=(L, 1, Nq)).T,
|
||||||
|
mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),
|
||||||
|
mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),
|
||||||
"causal",
|
"causal",
|
||||||
]
|
]
|
||||||
for m in masks:
|
for m in masks:
|
||||||
|
Loading…
Reference in New Issue
Block a user