causal vector sdpa (#2018)

* causal vector sdpa

* get rid of memory threshold
This commit is contained in:
Awni Hannun 2025-03-28 12:36:13 -07:00 committed by GitHub
parent 98b901ad66
commit 05d7118561
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 45 additions and 24 deletions

View File

@ -6,6 +6,7 @@ using namespace metal;
constant bool has_mask [[function_constant(20)]];
constant bool query_transposed [[function_constant(21)]];
constant bool do_causal [[function_constant(22)]];
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
@ -77,7 +78,13 @@ template <typename T, int D, int V = D>
// For each key
for (int i = simd_gid; i < N; i += BN) {
if (!has_mask || mask[0]) {
bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (has_mask) {
use_key = mask[0];
}
if (use_key) {
// Read the key
for (int j = 0; j < qk_per_thread; j++) {
k[j] = keys[j];
@ -218,7 +225,13 @@ template <typename T, int D, int V = D>
// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
if (!has_mask || mask[0]) {
bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
} else if (has_mask) {
use_key = mask[0];
}
if (use_key) {
// Read the key
for (int i = 0; i < qk_per_thread; i++) {
k[i] = keys[i];

View File

@ -138,6 +138,7 @@ void sdpa_vector(
const array& v,
array& out,
float scale,
bool do_causal,
const std::optional<array>& mask) {
// Set the kernel name
std::string kname;
@ -166,6 +167,7 @@ void sdpa_vector(
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
};
std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask";
@ -214,6 +216,7 @@ void sdpa_vector_2pass(
const array& v,
array& out,
float scale,
bool do_causal,
const std::optional<array>& mask) {
// Set the kernel name
std::string kname;
@ -260,6 +263,7 @@ void sdpa_vector_2pass(
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
{&query_transposed, MTL::DataType::DataTypeBool, 21},
{&do_causal, MTL::DataType::DataTypeBool, 22},
};
std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask";
@ -401,12 +405,13 @@ void ScaledDotProductAttention::eval_gpu(
// We route to the 2 pass fused attention if
// - The device is large and the sequence length long
// - The sequence length is even longer and we have gqa
bool do_causal = do_causal_ && q.shape(2) > 1;
char devc = d.get_architecture().back();
if ((devc == 'd' && k.shape(2) >= 1024) ||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
sdpa_vector_2pass(s, d, q, k, v, o, scale_, mask);
sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask);
} else {
sdpa_vector(s, d, q, k, v, o, scale_, mask);
sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask);
}
}

View File

@ -568,7 +568,6 @@ array scaled_dot_product_attention(
const array& values,
const float scale,
const std::variant<std::monostate, std::string, array>& mask /* = {}*/,
const std::optional<int> memory_efficient_threshold,
StreamOrDevice s) {
for (const auto& tensor : {queries, keys, values}) {
if (tensor.ndim() != 4) {
@ -654,13 +653,6 @@ array scaled_dot_product_attention(
auto k = astype(keys, final_type, s);
auto v = astype(values, final_type, s);
/* Generic implementation for use cases that Metal implementation does not
* support. */
int threshold = 32; // TODO: Fix after dev
if (memory_efficient_threshold.has_value()) {
threshold = std::max(1, memory_efficient_threshold.value());
}
auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s](
const std::vector<array>& inputs) {
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
@ -725,13 +717,13 @@ array scaled_dot_product_attention(
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
const bool sdpa_vector_supported_mask =
!has_mask || has_bool_mask || do_causal;
const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
(query_sequence_length <= key_sequence_length && do_causal);
const bool supports_sdpa_full = query_sequence_length >= threshold &&
sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
stream.device == Device::gpu;
const bool supports_sdpa_full = sdpa_full_supported_mask &&
sdpa_full_supported_head_dim && stream.device == Device::gpu;
const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
(query_sequence_length <= key_sequence_length) &&

View File

@ -49,7 +49,6 @@ array scaled_dot_product_attention(
const array& values,
const float scale,
const std::variant<std::monostate, std::string, array>& mask = {},
const std::optional<int> memory_efficient_threshold = std::nullopt,
StreamOrDevice s = {});
std::tuple<array, array, array> affine_quantize(

View File

@ -131,7 +131,6 @@ void init_fast(nb::module_& parent_module) {
nb::kw_only(),
"scale"_a,
"mask"_a = nb::none(),
"memory_efficient_threshold"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, stream: Union[None, Stream, Device] = None) -> array"),

View File

@ -95,7 +95,13 @@ def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
def mlx_primitives_sdpa(q, k, v, scale, mask=None):
p = (q * scale) @ k.transpose(0, 1, 3, 2)
if mask is not None:
if mask.dtype == mx.bool_:
if mask == "causal":
q_offset = max(0, k.shape[2] - q.shape[2])
q_indices = mx.arange(q_offset, q_offset + q.shape[2])
k_indices = mx.arange(k.shape[2])
mask = q_indices[:, None] >= k_indices[None]
p = mx.where(mask, p, mx.finfo(mx.float32).min)
elif mask.dtype == mx.bool_:
p = mx.where(mask, p, mx.finfo(mx.float32).min)
else:
p += mask
@ -176,7 +182,10 @@ class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
o_mlx = mx.fast.scaled_dot_product_attention(
q_mlx, k_mlx, v_mlx, scale=scale, memory_efficient_threshold=2
q_mlx,
k_mlx,
v_mlx,
scale=scale,
)
self.assertListEqual(list(reference.shape), list(o_mlx.shape))
@ -342,6 +351,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
"causal",
]
for m in masks:
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
@ -366,6 +376,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
"causal",
]
for m in masks:
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
@ -396,6 +407,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
"causal",
]
for m in masks:
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
@ -420,6 +432,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
"causal",
]
for m in masks:
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)