Add boolean mask support in vector SDPA (#1757)

This commit is contained in:
Awni Hannun 2025-01-07 20:24:53 -08:00 committed by GitHub
parent 516ded618b
commit d1766f2c70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 226 additions and 74 deletions

View File

@ -12,7 +12,7 @@ dtype = mx.float16
loops = 10 loops = 10
def attention(q, k, v): def attention(q, k, v, mask=None):
def _sdpa(q, k, v): def _sdpa(q, k, v):
B, Hq, L, D = q.shape B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape _, Hk, S, _ = k.shape
@ -20,6 +20,9 @@ def attention(q, k, v):
k = k[:, :, None, :, :] k = k[:, :, None, :, :]
v = v[:, :, None, :, :] v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3) s = q @ k.transpose(0, 1, 2, 4, 3)
if mask is not None:
m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
s = mx.where(m, s, mx.finfo(s.dtype).min)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v o = p @ v
return o.reshape(B, Hq, L, D) return o.reshape(B, Hq, L, D)
@ -29,9 +32,9 @@ def attention(q, k, v):
return q return q
def sdpa(q, k, v): def sdpa(q, k, v, mask=None):
for i in range(loops): for i in range(loops):
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
return q return q
@ -53,6 +56,26 @@ def time_self_attention_sdpa():
time_fn(sdpa, q, k, v) time_fn(sdpa, q, k, v)
def time_self_attention_sdpa_with_mask():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
mask = mx.full((L,), True)
mask[L // 2 :] = False
mx.eval(q, k, v, mask)
def sdpa_mask(*args):
return sdpa(*args, mask=mask)
def attention_mask(*args):
return attention(*args, mask=mask)
time_fn(attention_mask, q, k, v)
time_fn(sdpa_mask, q, k, v)
if __name__ == "__main__": if __name__ == "__main__":
time_self_attention_sdpa() time_self_attention_sdpa()
time_self_attention_primitives() time_self_attention_primitives()
time_self_attention_sdpa_with_mask()

View File

@ -4,6 +4,8 @@
using namespace metal; using namespace metal;
constant bool has_mask [[function_constant(20)]];
template <typename T, int D> template <typename T, int D>
[[kernel]] void sdpa_vector( [[kernel]] void sdpa_vector(
const device T* queries [[buffer(0)]], const device T* queries [[buffer(0)]],
@ -15,6 +17,9 @@ template <typename T, int D>
const constant size_t& k_stride, const constant size_t& k_stride,
const constant size_t& v_stride, const constant size_t& v_stride,
const constant float& scale, const constant float& scale,
const device bool* mask [[function_constant(has_mask)]],
const constant int& mask_seq_stride [[function_constant(has_mask)]],
const constant int& mask_head_stride [[function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
@ -39,6 +44,9 @@ template <typename T, int D>
queries += head_idx * D + simd_lid * elem_per_thread; queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread; values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
if (has_mask) {
mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
}
out += head_idx * D + simd_gid * elem_per_thread; out += head_idx * D + simd_gid * elem_per_thread;
// Read the query and 0 the output accumulator // Read the query and 0 the output accumulator
@ -54,34 +62,39 @@ template <typename T, int D>
// For each key // For each key
for (int i = simd_gid; i < N; i += BN) { for (int i = simd_gid; i < N; i += BN) {
// Read the key if (!has_mask || mask[0]) {
for (int i = 0; i < elem_per_thread; i++) { // Read the key
k[i] = keys[i]; for (int j = 0; j < elem_per_thread; j++) {
} k[j] = keys[j];
}
// Compute the i-th score // Compute the i-th score
U score = 0; U score = 0;
for (int i = 0; i < elem_per_thread; i++) { for (int j = 0; j < elem_per_thread; j++) {
score += q[i] * k[i]; score += q[j] * k[j];
} }
score = simd_sum(score); score = simd_sum(score);
// Update the accumulators // Update the accumulators
U new_max = max(max_score, score); U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max); U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max); U exp_score = fast::exp(score - new_max);
max_score = new_max; max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score; sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator // Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) { for (int j = 0; j < elem_per_thread; j++) {
o[i] = o[i] * factor + exp_score * values[i]; o[j] = o[j] * factor + exp_score * values[j];
}
} }
// Move the pointers to the next kv // Move the pointers to the next kv
keys += stride; keys += stride;
values += stride; values += stride;
if (has_mask) {
mask += BN * mask_seq_stride;
}
} }
// Each thread has a partial part of the output so we need to combine them. // Each thread has a partial part of the output so we need to combine them.
@ -126,6 +139,9 @@ template <typename T, int D>
const constant size_t& k_stride, const constant size_t& k_stride,
const constant size_t& v_stride, const constant size_t& v_stride,
const constant float& scale, const constant float& scale,
const device bool* mask [[function_constant(has_mask)]],
const constant int& mask_seq_stride [[function_constant(has_mask)]],
const constant int& mask_head_stride [[function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
@ -155,6 +171,10 @@ template <typename T, int D>
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D + values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
simd_lid * elem_per_thread; simd_lid * elem_per_thread;
out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread; out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
if (has_mask) {
mask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_seq_stride;
}
sums += head_idx * blocks + block_idx; sums += head_idx * blocks + block_idx;
maxs += head_idx * blocks + block_idx; maxs += head_idx * blocks + block_idx;
@ -171,34 +191,39 @@ template <typename T, int D>
// For each key // For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
// Read the key if (!has_mask || mask[0]) {
for (int i = 0; i < elem_per_thread; i++) { // Read the key
k[i] = keys[i]; for (int i = 0; i < elem_per_thread; i++) {
} k[i] = keys[i];
}
// Compute the i-th score // Compute the i-th score
U score = 0; U score = 0;
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i]; score += q[i] * k[i];
} }
score = simd_sum(score); score = simd_sum(score);
// Update the accumulators // Update the accumulators
U new_max = max(max_score, score); U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max); U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max); U exp_score = fast::exp(score - new_max);
max_score = new_max; max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score; sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator // Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) { for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i]; o[i] = o[i] * factor + exp_score * values[i];
}
} }
// Move the pointers to the next kv // Move the pointers to the next kv
keys += blocks * stride; keys += blocks * stride;
values += blocks * stride; values += blocks * stride;
if (has_mask) {
mask += BN * blocks * mask_seq_stride;
}
} }
// Each thread has a partial part of the output so we need to combine them. // Each thread has a partial part of the output so we need to combine them.

View File

@ -1,6 +1,5 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include <cassert>
#include <sstream> #include <sstream>
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
@ -116,7 +115,8 @@ void sdpa_vector(
const array& k, const array& k,
const array& v, const array& v,
array& out, array& out,
float scale) { float scale,
const std::optional<array>& mask) {
// Set the kernel name // Set the kernel name
std::string kname; std::string kname;
kname.reserve(64); kname.reserve(64);
@ -134,9 +134,16 @@ void sdpa_vector(
MTL::Size group_dims(1024, 1, 1); MTL::Size group_dims(1024, 1, 1);
MTL::Size grid_dims(1, B, 1); MTL::Size grid_dims(1, B, 1);
bool has_mask = mask.has_value();
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
};
std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask";
// Get the kernel // Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname); auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments // Set its arguments
@ -149,6 +156,14 @@ void sdpa_vector(
compute_encoder.set_bytes(k_stride, 6); compute_encoder.set_bytes(k_stride, 6);
compute_encoder.set_bytes(v_stride, 7); compute_encoder.set_bytes(v_stride, 7);
compute_encoder.set_bytes(scale, 8); compute_encoder.set_bytes(scale, 8);
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 9);
int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
compute_encoder.set_bytes(seq_stride, 10);
compute_encoder.set_bytes(head_stride, 11);
}
// Launch // Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@ -161,7 +176,8 @@ void sdpa_vector_2pass(
const array& k, const array& k,
const array& v, const array& v,
array& out, array& out,
float scale) { float scale,
const std::optional<array>& mask) {
// Set the kernel name // Set the kernel name
std::string kname; std::string kname;
kname.reserve(64); kname.reserve(64);
@ -198,9 +214,17 @@ void sdpa_vector_2pass(
d.add_temporary(sums, s.index); d.add_temporary(sums, s.index);
d.add_temporary(maxs, s.index); d.add_temporary(maxs, s.index);
bool has_mask = mask.has_value();
metal::MTLFCList func_consts = {
{&has_mask, MTL::DataType::DataTypeBool, 20},
};
std::string hash_name = kname;
hash_name += has_mask ? "_mask" : "_nomask";
// Get the kernel // Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname); auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments // Set its arguments
@ -215,6 +239,14 @@ void sdpa_vector_2pass(
compute_encoder.set_bytes(k_stride, 8); compute_encoder.set_bytes(k_stride, 8);
compute_encoder.set_bytes(v_stride, 9); compute_encoder.set_bytes(v_stride, 9);
compute_encoder.set_bytes(scale, 10); compute_encoder.set_bytes(scale, 10);
if (has_mask) {
auto& m = *mask;
compute_encoder.set_input_array(m, 11);
int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
compute_encoder.set_bytes(seq_stride, 12);
compute_encoder.set_bytes(head_stride, 13);
}
// Launch // Launch
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@ -247,8 +279,6 @@ void sdpa_vector_2pass(
void ScaledDotProductAttention::eval_gpu( void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out) { array& out) {
assert(inputs.size() == 3);
auto& s = stream(); auto& s = stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
@ -296,6 +326,8 @@ void ScaledDotProductAttention::eval_gpu(
// We are in vector mode ie single query // We are in vector mode ie single query
if (q_pre.shape(2) == 1) { if (q_pre.shape(2) == 1) {
const auto& q = copy_unless(is_contiguous, q_pre); const auto& q = copy_unless(is_contiguous, q_pre);
// 1, heads, seq_len, head_dim
// mask [1, query_heads, 1, seq_len]
const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre); const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre); const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
@ -306,15 +338,18 @@ void ScaledDotProductAttention::eval_gpu(
o.set_data(allocator::malloc_or_wait(o.nbytes())); o.set_data(allocator::malloc_or_wait(o.nbytes()));
} }
auto mask =
inputs.size() > 3 ? std::optional<array>{inputs[3]} : std::nullopt;
// We route to the 2 pass fused attention if // We route to the 2 pass fused attention if
// - The device is large and the sequence length long // - The device is large and the sequence length long
// - The sequence length is even longer and we have gqa // - The sequence length is even longer and we have gqa
char devc = d.get_architecture().back(); char devc = d.get_architecture().back();
if ((devc == 'd' && k.shape(2) >= 1024) || if ((devc == 'd' && k.shape(2) >= 1024) ||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) { (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
sdpa_vector_2pass(s, d, q, k, v, o, scale_); sdpa_vector_2pass(s, d, q, k, v, o, scale_, mask);
} else { } else {
sdpa_vector(s, d, q, k, v, o, scale_); sdpa_vector(s, d, q, k, v, o, scale_, mask);
} }
} }

View File

@ -609,27 +609,32 @@ array scaled_dot_product_attention(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (mask && promote_types((*mask).dtype(), final_type) != final_type) { if (mask) {
std::ostringstream msg; // Check type
msg << "[scaled_dot_product_attention] Mask type must promote to output type. " if (promote_types(mask->dtype(), final_type) != final_type) {
<< final_type << "."; std::ostringstream msg;
throw std::invalid_argument(msg.str()); msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
<< final_type << ".";
throw std::invalid_argument(msg.str());
}
// Check shape
auto mask_shape = queries.shape();
mask_shape.back() = keys.shape(-2);
if (broadcast_shapes(mask->shape(), mask_shape) != mask_shape) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Mask with shape " << mask->shape()
<< " does not broadcast to implicit scores with shape " << mask_shape
<< ".";
throw std::invalid_argument(msg.str());
}
} }
auto q = astype(queries, final_type, s); auto q = astype(queries, final_type, s);
auto k = astype(keys, final_type, s); auto k = astype(keys, final_type, s);
auto v = astype(values, final_type, s); auto v = astype(values, final_type, s);
/* generic implementation for use cases that Metal implementation does not /* Generic implementation for use cases that Metal implementation does not
* support. For non-supported cases listed below, use MLX primitives: * support. */
* * CPU implementation
* * batch size > 1 for decoding or causal attention
* * query sequence length > 1 for decoding
* * query sequence length > 16 && non-null mask (causal attention)
* * non-null mask
* * dtype is not fp32 or fp16
*/
int threshold = 32; // TODO: Fix after dev int threshold = 32; // TODO: Fix after dev
if (memory_efficient_threshold.has_value()) { if (memory_efficient_threshold.has_value()) {
threshold = std::max(1, memory_efficient_threshold.value()); threshold = std::max(1, memory_efficient_threshold.value());
@ -690,27 +695,27 @@ array scaled_dot_product_attention(
!mask.has_value() && sdpa_full_supported_head_dim && !mask.has_value() && sdpa_full_supported_head_dim &&
stream.device == Device::gpu; stream.device == Device::gpu;
const bool supported_mask = !mask || (mask->dtype() == bool_);
const bool supports_sdpa_vector = query_sequence_length == 1 && const bool supports_sdpa_vector = query_sequence_length == 1 &&
!mask.has_value() && sdpa_vector_supported_head_dim && supported_mask && sdpa_vector_supported_head_dim &&
stream.device == Device::gpu; stream.device == Device::gpu;
implementation_supports_use_case &= implementation_supports_use_case &=
supports_sdpa_full || supports_sdpa_vector; supports_sdpa_full || supports_sdpa_vector;
std::vector<array> inputs = {q, k, v};
if (mask) {
inputs.push_back(*mask);
}
if (implementation_supports_use_case) { if (implementation_supports_use_case) {
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}; auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
return array( return array(
std::move(out_shape), std::move(out_shape),
final_type, final_type,
std::make_shared<ScaledDotProductAttention>(stream, fallback, scale), std::make_shared<ScaledDotProductAttention>(stream, fallback, scale),
{q, k, v}); std::move(inputs));
}
if (mask.has_value()) {
return fallback({q, k, v, mask.value()})[0];
} else {
return fallback({q, k, v})[0];
} }
return fallback(inputs)[0];
} }
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {

View File

@ -10,7 +10,10 @@ import numpy as np
def mlx_primitives_sdpa(q, k, v, scale, mask=None): def mlx_primitives_sdpa(q, k, v, scale, mask=None):
p = (q * scale) @ k.transpose(0, 1, 3, 2) p = (q * scale) @ k.transpose(0, 1, 3, 2)
if mask is not None: if mask is not None:
p += mask if mask.dtype == mx.bool_:
p = mx.where(mask, p, mx.finfo(mx.float32).min)
else:
p += mask
scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype) scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype)
return scores @ v return scores @ v
@ -198,6 +201,67 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
) )
self.assertTrue(mx.allclose(y, y_hat, atol=atol)) self.assertTrue(mx.allclose(y, y_hat, atol=atol))
def test_fast_sdpa_vector(self):
D = 64
L = 43
Nq = 4
Nkv = 1
scale = 1.0
mx.random.seed(0)
q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
with self.assertRaises(ValueError):
mx.fast.scaled_dot_product_attention(
q,
k,
v,
scale=scale,
mask=mx.full((Nq, 2, L), False),
)
masks = [
mx.array(True),
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
]
for m in masks:
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
out = mx.fast.scaled_dot_product_attention(
q,
k,
v,
scale=scale,
mask=m,
)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
L = 4096
scale = 1.0
mx.random.seed(0)
q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
masks = [
mx.array(True),
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
]
for m in masks:
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
out = mx.fast.scaled_dot_product_attention(
q,
k,
v,
scale=scale,
mask=m,
)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main(failfast=True) unittest.main(failfast=True)