[CUDA] Support array mask in SDPA (#2822)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled

This commit is contained in:
Cheng
2025-11-26 11:08:58 +09:00
committed by GitHub
parent c9f4dc851f
commit 704fd1ae28
6 changed files with 146 additions and 47 deletions

View File

@@ -107,6 +107,8 @@ struct SDPACacheKey {
std::array<int64_t, QKV_NDIM> k_strides;
std::array<int64_t, QKV_NDIM> v_strides;
bool do_causal;
std::array<int, QKV_NDIM> mask_shape;
std::array<int64_t, QKV_NDIM> mask_strides;
bool output_logsumexp;
};
@@ -116,6 +118,7 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
const array& k,
const array& v,
bool do_causal,
const std::optional<array>& mask_arr,
bool output_logsumexp = true) {
BytesKey<SDPACacheKey> cache_key;
cache_key.pod = {
@@ -128,8 +131,14 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
vector_key<QKV_NDIM>(k.strides()),
vector_key<QKV_NDIM>(v.strides()),
do_causal,
{},
{},
output_logsumexp,
};
if (mask_arr) {
cache_key.pod.mask_shape = vector_key<QKV_NDIM>(mask_arr->shape());
cache_key.pod.mask_strides = vector_key<QKV_NDIM>(mask_arr->strides());
}
return cache_key;
}
@@ -150,6 +159,7 @@ enum UIDS {
K,
V,
SCALE,
BIAS,
O,
STATS,
// Backward graph:
@@ -165,6 +175,7 @@ fe::graph::Graph build_sdpa_graph(
const array& k,
const array& v,
bool do_causal,
const std::optional<array>& mask_arr,
bool output_logsumexp,
const array& o,
const array& stats) {
@@ -198,6 +209,11 @@ fe::graph::Graph build_sdpa_graph(
.set_attn_scale(scale)
.set_causal_mask(do_causal)
.set_generate_stats(output_logsumexp);
if (mask_arr) {
auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
set_tensor_attrs(bias_, BIAS, *mask_arr);
options.set_bias(bias_);
}
auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
o_->set_output(true);
@@ -224,6 +240,7 @@ fe::graph::Graph build_sdpa_backward_graph(
const array& k,
const array& v,
bool do_causal,
const std::optional<array>& mask_arr,
const array& o,
const array& d_o,
const array& stats,
@@ -266,6 +283,11 @@ fe::graph::Graph build_sdpa_backward_graph(
.set_name("sdpa_backward_cudnn")
.set_attn_scale(scale)
.set_causal_mask(do_causal);
if (mask_arr) {
auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
set_tensor_attrs(bias_, BIAS, *mask_arr);
options.set_bias(bias_);
}
auto [d_q_, d_k_, d_v_] =
graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
@@ -318,8 +340,6 @@ bool supports_sdpa_cudnn(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool do_causal,
Stream s) {
static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
if (!enabled) {
@@ -331,17 +351,6 @@ bool supports_sdpa_cudnn(
return false;
}
if (has_mask) {
// TODO: Support array masks.
if (!do_causal) {
return false;
}
// FIXME: Causal mask generates wrong results when L_Q != L_K.
if (q.shape(2) != k.shape(2)) {
return false;
}
}
// Only use cuDNN for prefilling and training.
if (q.shape(2) != k.shape(2)) {
return false;
@@ -365,6 +374,7 @@ void sdpa_cudnn(
array& o,
array& stats,
bool do_causal,
const std::optional<array>& mask_arr,
bool output_logsumexp,
Stream s) {
auto& encoder = cu::get_command_encoder(s);
@@ -376,19 +386,21 @@ void sdpa_cudnn(
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(o);
if (mask_arr) {
encoder.set_input_array(*mask_arr);
}
if (output_logsumexp) {
stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
encoder.set_output_array(stats);
}
// Search cache.
auto cache_key =
build_sdpa_cache_key(encoder, q, k, v, do_causal, output_logsumexp);
auto cache_key = build_sdpa_cache_key(
encoder, q, k, v, do_causal, mask_arr, output_logsumexp);
auto it = sdpa_cache().find(cache_key);
if (it == sdpa_cache().end()) {
auto graph = build_sdpa_graph(
handle, q, k, v, do_causal, output_logsumexp, o, stats);
handle, q, k, v, do_causal, mask_arr, output_logsumexp, o, stats);
it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;
@@ -399,6 +411,9 @@ void sdpa_cudnn(
{V, const_cast<void*>(gpu_ptr<void>(v))},
{SCALE, &scale},
{O, gpu_ptr<void>(o)}};
if (mask_arr) {
variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
}
if (output_logsumexp) {
variant_pack[STATS] = gpu_ptr<void>(stats);
}
@@ -414,6 +429,7 @@ void sdpa_backward_cudnn(
const array& o,
const array& stats,
bool do_causal,
const std::optional<array>& mask_arr,
const array& d_o,
array& d_q,
array& d_k,
@@ -435,13 +451,16 @@ void sdpa_backward_cudnn(
encoder.set_output_array(d_q);
encoder.set_output_array(d_k);
encoder.set_output_array(d_v);
if (mask_arr) {
encoder.set_input_array(*mask_arr);
}
// Search cache.
auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal);
auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal, mask_arr);
auto it = sdpa_backward_cache().find(cache_key);
if (it == sdpa_backward_cache().end()) {
auto graph = build_sdpa_backward_graph(
handle, q, k, v, do_causal, o, d_o, stats, d_q, d_k, d_v);
handle, q, k, v, do_causal, mask_arr, o, d_o, stats, d_q, d_k, d_v);
it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;
@@ -457,6 +476,9 @@ void sdpa_backward_cudnn(
{D_Q, gpu_ptr<void>(d_q)},
{D_K, gpu_ptr<void>(d_k)},
{D_V, gpu_ptr<void>(d_v)}};
if (mask_arr) {
variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
}
execute_graph(encoder, handle, graph, variant_pack);
}
@@ -498,7 +520,11 @@ bool ScaledDotProductAttention::use_fallback(
return !supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
!supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s);
!supports_sdpa_cudnn(q, k, v, s);
}
// Reports whether a boolean mask can be handed to this backend unchanged.
// The cuDNN graphs in this file attach the mask through options.set_bias(),
// so the answer here is "no": scaled_dot_product_attention checks this
// value and rewrites a boolean mask as an additive one before building the
// primitive.
bool ScaledDotProductAttention::supports_bool_mask() {
  constexpr bool kAcceptsBoolMask = false;
  return kAcceptsBoolMask;
}
void ScaledDotProductAttention::eval_gpu(
@@ -516,6 +542,11 @@ void ScaledDotProductAttention::eval_gpu(
bool has_mask = inputs.size() - has_sinks_ > 3;
bool has_arr_mask = has_mask && !do_causal_;
std::optional<array> mask_arr;
if (has_arr_mask) {
mask_arr = prepare_sdpa_input(inputs[3], s);
}
if (supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
if (has_sinks_) {
@@ -524,7 +555,17 @@ void ScaledDotProductAttention::eval_gpu(
sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
}
} else {
sdpa_cudnn(q, k, v, scale_, out, stats, do_causal_, output_logsumexp_, s);
sdpa_cudnn(
q,
k,
v,
scale_,
out,
stats,
do_causal_,
mask_arr,
output_logsumexp_,
s);
}
}
@@ -544,13 +585,21 @@ void ScaledDotProductAttentionVJP::eval_gpu(
auto& s = stream();
assert(inputs.size() == 6);
assert(inputs.size() >= 6);
int primals_size = inputs.size() - 3;
bool has_arr_mask = primals_size > 3 + has_sinks_;
array q = prepare_sdpa_input(inputs[0], s);
array k = prepare_sdpa_input(inputs[1], s);
array v = prepare_sdpa_input(inputs[2], s);
array o = prepare_sdpa_input(inputs[3], s);
array stats = prepare_sdpa_input(inputs[4], s);
array d_o = prepare_sdpa_input(inputs[5], s);
array o = prepare_sdpa_input(inputs[primals_size], s);
array stats = prepare_sdpa_input(inputs[primals_size + 1], s);
array d_o = prepare_sdpa_input(inputs[primals_size + 2], s);
std::optional<array> mask_arr;
if (has_arr_mask) {
mask_arr = prepare_sdpa_input(inputs[3], s);
}
assert(outputs.size() == 3);
auto& d_q = outputs[0];
@@ -558,7 +607,7 @@ void ScaledDotProductAttentionVJP::eval_gpu(
auto& d_v = outputs[2];
sdpa_backward_cudnn(
q, k, v, scale_, o, stats, do_causal_, d_o, d_q, d_k, d_v, s);
q, k, v, scale_, o, stats, do_causal_, mask_arr, d_o, d_q, d_k, d_v, s);
}
} // namespace fast

View File

@@ -569,6 +569,10 @@ bool ScaledDotProductAttention::use_fallback(
return !(supports_sdpa_full || supports_sdpa_vector);
}
// Reports whether a boolean mask can be handed to this backend unchanged.
// This implementation answers "yes", so scaled_dot_product_attention skips
// its bool-to-additive mask conversion and passes the mask through as-is.
bool ScaledDotProductAttention::supports_bool_mask() {
  constexpr bool kAcceptsBoolMask = true;
  return kAcceptsBoolMask;
}
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {

View File

@@ -36,6 +36,10 @@ bool fast::ScaledDotProductAttention::use_fallback(
return true;
}
// Reports whether a boolean mask can be handed to this backend unchanged.
// This implementation answers "no"; scaled_dot_product_attention then
// rewrites a boolean mask as an additive one before creating the primitive.
bool fast::ScaledDotProductAttention::supports_bool_mask() {
  constexpr bool kAcceptsBoolMask = false;
  return kAcceptsBoolMask;
}
bool fast::ScaledDotProductAttentionVJP::use_fallback(
const array& q,
Stream s) {

View File

@@ -800,6 +800,15 @@ array scaled_dot_product_attention(
is_training,
output_logsumexp,
stream)) {
if (has_bool_mask && !ScaledDotProductAttention::supports_bool_mask()) {
// Convert bool mask to additive mask.
float inf = std::numeric_limits<float>::infinity();
array& mask = inputs[3];
mask = where(
mask,
full_like(mask, 0, final_type, s),
full_like(mask, -inf, final_type, s));
}
Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
auto primitive = std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, do_causal, has_sinks, output_logsumexp);
@@ -839,7 +848,7 @@ std::vector<array> ScaledDotProductAttention::vjp(
std::vector<Shape> shapes;
std::vector<Dtype> dtypes;
for (int i = 0; i < primals.size(); ++i) {
for (int i = 0; i < /* outputs size */ 3; ++i) {
shapes.push_back(primals[i].shape());
dtypes.push_back(primals[i].dtype());
}

View File

@@ -228,6 +228,7 @@ class ScaledDotProductAttention : public Custom {
bool is_training,
bool output_logsumexp,
Stream s);
// Returns true when the backend can take a boolean mask unchanged. Each
// backend file supplies its own definition. When this returns false,
// scaled_dot_product_attention replaces a boolean mask with an additive
// mask (0 where the mask is true, -inf where it is false) before creating
// the primitive.
static bool supports_bool_mask();
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {

View File

@@ -739,37 +739,69 @@ class TestSDPA(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out, expected, atol=1e-5))
def test_sdpa_grad(self):
B, N_kv, T, D = (2, 8, 128, 64)
scale = D**-0.5
f1 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
f2 = lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
f3 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale).sum()
f4 = lambda q, k, v: (
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
).sum()
# High tolerance due to cuDNN SDPA kernel requiring tf32.
tolerance = {"rtol": 1e-2, "atol": 1e-2}
def test_vjp(slow, fast, primals):
cotan = mx.ones_like(primals[0])
o1, vjp1 = mx.vjp(slow, primals, [cotan])
o2, vjp2 = mx.vjp(fast, primals, [cotan])
self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
for i in range(3):
self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
def test_grad(slow, fast, args):
g1 = mx.grad(slow)(*args)
g2 = mx.grad(fast)(*args)
self.assertTrue(mx.allclose(g1, g2, **tolerance))
sdpa_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
)
sdpa_mask_fast = lambda q, k, v, mask: mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale, mask=mask
)
loss_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
).sum()
loss_mask_fast = lambda q, k, v, mask: (
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
).sum()
B, N_kv, T, D = (2, 8, 128, 64)
scale = D**-0.5
for N_q in (8, 32):
q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)
k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
cotan = mx.ones_like(q)
o1, vjp1 = mx.vjp(f1, [q, k, v], [cotan])
o2, vjp2 = mx.vjp(f2, [q, k, v], [cotan])
mask_additive = mx.random.normal((B, N_q, T, T), dtype=mx.float16)
mask_bool = mx.random.uniform(0, 1, (B, N_q, T, T), dtype=mx.float16) < 0.5
self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
for i in range(3):
self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
for mask in (mask_additive, mask_bool):
test_vjp(sdpa_mask_slow, sdpa_mask_fast, [q, k, v, mask])
test_grad(loss_mask_slow, loss_mask_fast, [q, k, v, mask])
g1 = mx.grad(f3)(q, k, v)
g2 = mx.grad(f4)(q, k, v)
for mask in (None, "causal"):
sdpa_slow = lambda q, k, v: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
)
sdpa_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale, mask=mask
)
test_vjp(sdpa_slow, sdpa_fast, [q, k, v])
self.assertTrue(mx.allclose(g1, g2, **tolerance))
loss_slow = lambda q, k, v: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
).sum()
loss_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale, mask=mask
).sum()
test_grad(loss_slow, loss_fast, [q, k, v])
if __name__ == "__main__":