diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp
index 4cc94f430..556e64990 100644
--- a/mlx/backend/cuda/scaled_dot_product_attention.cpp
+++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp
@@ -107,6 +107,8 @@ struct SDPACacheKey {
   std::array k_strides;
   std::array v_strides;
   bool do_causal;
+  std::array mask_shape;
+  std::array mask_strides;
   bool output_logsumexp;
 };
 
@@ -116,6 +118,7 @@ inline BytesKey build_sdpa_cache_key(
     const array& k,
     const array& v,
     bool do_causal,
+    const std::optional<array>& mask_arr,
     bool output_logsumexp = true) {
   BytesKey cache_key;
   cache_key.pod = {
@@ -128,8 +131,14 @@ inline BytesKey build_sdpa_cache_key(
       vector_key(k.strides()),
       vector_key(v.strides()),
       do_causal,
+      {},
+      {},
       output_logsumexp,
   };
+  if (mask_arr) {
+    cache_key.pod.mask_shape = vector_key(mask_arr->shape());
+    cache_key.pod.mask_strides = vector_key(mask_arr->strides());
+  }
   return cache_key;
 }
 
@@ -150,6 +159,7 @@ enum UIDS {
   K,
   V,
   SCALE,
+  BIAS,
   O,
   STATS,
   // Backward graph:
@@ -165,6 +175,7 @@ fe::graph::Graph build_sdpa_graph(
     const array& k,
     const array& v,
     bool do_causal,
+    const std::optional<array>& mask_arr,
     bool output_logsumexp,
     const array& o,
     const array& stats) {
@@ -198,6 +209,11 @@ fe::graph::Graph build_sdpa_graph(
                      .set_attn_scale(scale)
                      .set_causal_mask(do_causal)
                      .set_generate_stats(output_logsumexp);
+  if (mask_arr) {
+    auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
+    set_tensor_attrs(bias_, BIAS, *mask_arr);
+    options.set_bias(bias_);
+  }
 
   auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
   o_->set_output(true);
@@ -224,6 +240,7 @@ fe::graph::Graph build_sdpa_backward_graph(
     const array& k,
     const array& v,
     bool do_causal,
+    const std::optional<array>& mask_arr,
     const array& o,
     const array& d_o,
     const array& stats,
@@ -266,6 +283,11 @@ fe::graph::Graph build_sdpa_backward_graph(
                      .set_name("sdpa_backward_cudnn")
                      .set_attn_scale(scale)
                      .set_causal_mask(do_causal);
+  if (mask_arr) {
+    auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
+    set_tensor_attrs(bias_, BIAS, *mask_arr);
+    options.set_bias(bias_);
+  }
 
   auto [d_q_, d_k_, d_v_] =
       graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
@@ -318,8 +340,6 @@ bool supports_sdpa_cudnn(
     const array& q,
     const array& k,
     const array& v,
-    bool has_mask,
-    bool do_causal,
     Stream s) {
   static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
   if (!enabled) {
@@ -331,17 +351,6 @@ bool supports_sdpa_cudnn(
     return false;
   }
 
-  if (has_mask) {
-    // TODO: Support array masks.
-    if (!do_causal) {
-      return false;
-    }
-    // FIXME: Causal mask generates wrong results when L_Q != L_K.
-    if (q.shape(2) != k.shape(2)) {
-      return false;
-    }
-  }
-
   // Only use cuDNN for prefilling and training.
   if (q.shape(2) != k.shape(2)) {
     return false;
@@ -365,6 +374,7 @@ void sdpa_cudnn(
     array& o,
     array& stats,
     bool do_causal,
+    const std::optional<array>& mask_arr,
     bool output_logsumexp,
     Stream s) {
   auto& encoder = cu::get_command_encoder(s);
@@ -376,19 +386,21 @@ void sdpa_cudnn(
   encoder.set_input_array(k);
   encoder.set_input_array(v);
   encoder.set_output_array(o);
-
+  if (mask_arr) {
+    encoder.set_input_array(*mask_arr);
+  }
   if (output_logsumexp) {
     stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
     encoder.set_output_array(stats);
   }
 
   // Search cache.
-  auto cache_key =
-      build_sdpa_cache_key(encoder, q, k, v, do_causal, output_logsumexp);
+  auto cache_key = build_sdpa_cache_key(
+      encoder, q, k, v, do_causal, mask_arr, output_logsumexp);
   auto it = sdpa_cache().find(cache_key);
   if (it == sdpa_cache().end()) {
     auto graph = build_sdpa_graph(
-        handle, q, k, v, do_causal, output_logsumexp, o, stats);
+        handle, q, k, v, do_causal, mask_arr, output_logsumexp, o, stats);
     it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
   }
   auto& graph = it->second;
@@ -399,6 +411,9 @@ void sdpa_cudnn(
       {V, const_cast<void*>(gpu_ptr(v))},
       {SCALE, &scale},
       {O, gpu_ptr(o)}};
+  if (mask_arr) {
+    variant_pack[BIAS] = const_cast<void*>(gpu_ptr(*mask_arr));
+  }
   if (output_logsumexp) {
     variant_pack[STATS] = gpu_ptr(stats);
   }
@@ -414,6 +429,7 @@ void sdpa_backward_cudnn(
     const array& o,
     const array& stats,
     bool do_causal,
+    const std::optional<array>& mask_arr,
     const array& d_o,
     array& d_q,
     array& d_k,
@@ -435,13 +451,16 @@ void sdpa_backward_cudnn(
   encoder.set_output_array(d_q);
   encoder.set_output_array(d_k);
   encoder.set_output_array(d_v);
+  if (mask_arr) {
+    encoder.set_input_array(*mask_arr);
+  }
 
   // Search cache.
-  auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal);
+  auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal, mask_arr);
   auto it = sdpa_backward_cache().find(cache_key);
   if (it == sdpa_backward_cache().end()) {
     auto graph = build_sdpa_backward_graph(
-        handle, q, k, v, do_causal, o, d_o, stats, d_q, d_k, d_v);
+        handle, q, k, v, do_causal, mask_arr, o, d_o, stats, d_q, d_k, d_v);
     it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
   }
   auto& graph = it->second;
@@ -457,6 +476,9 @@ void sdpa_backward_cudnn(
       {D_Q, gpu_ptr(d_q)},
       {D_K, gpu_ptr(d_k)},
      {D_V, gpu_ptr(d_v)}};
+  if (mask_arr) {
+    variant_pack[BIAS] = const_cast<void*>(gpu_ptr(*mask_arr));
+  }
 
   execute_graph(encoder, handle, graph, variant_pack);
 }
@@ -498,7 +520,11 @@ bool ScaledDotProductAttention::use_fallback(
 
   return !supports_sdpa_vector(
              q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
-      !supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s);
+      !supports_sdpa_cudnn(q, k, v, s);
+}
+
+bool ScaledDotProductAttention::supports_bool_mask() {
+  return false;
 }
 
 void ScaledDotProductAttention::eval_gpu(
@@ -516,6 +542,11 @@ void ScaledDotProductAttention::eval_gpu(
   bool has_mask = inputs.size() - has_sinks_ > 3;
   bool has_arr_mask = has_mask && !do_causal_;
 
+  std::optional<array> mask_arr;
+  if (has_arr_mask) {
+    mask_arr = prepare_sdpa_input(inputs[3], s);
+  }
+
   if (supports_sdpa_vector(
           q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
     if (has_sinks_) {
@@ -524,7 +555,17 @@ void ScaledDotProductAttention::eval_gpu(
       sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
     }
   } else {
-    sdpa_cudnn(q, k, v, scale_, out, stats, do_causal_, output_logsumexp_, s);
+    sdpa_cudnn(
+        q,
+        k,
+        v,
+        scale_,
+        out,
+        stats,
+        do_causal_,
+        mask_arr,
+        output_logsumexp_,
+        s);
   }
 }
 
@@ -544,13 +585,21 @@ void ScaledDotProductAttentionVJP::eval_gpu(
 
   auto& s = stream();
 
-  assert(inputs.size() == 6);
+  assert(inputs.size() >= 6);
+  int primals_size = inputs.size() - 3;
+  bool has_arr_mask = primals_size > 3 + has_sinks_;
+
   array q = prepare_sdpa_input(inputs[0], s);
   array k = prepare_sdpa_input(inputs[1], s);
   array v = prepare_sdpa_input(inputs[2], s);
-  array o = prepare_sdpa_input(inputs[3], s);
-  array stats = prepare_sdpa_input(inputs[4], s);
-  array d_o = prepare_sdpa_input(inputs[5], s);
+  array o = prepare_sdpa_input(inputs[primals_size], s);
+  array stats = prepare_sdpa_input(inputs[primals_size + 1], s);
+  array d_o = prepare_sdpa_input(inputs[primals_size + 2], s);
+
+  std::optional<array> mask_arr;
+  if (has_arr_mask) {
+    mask_arr = prepare_sdpa_input(inputs[3], s);
+  }
 
   assert(outputs.size() == 3);
   auto& d_q = outputs[0];
@@ -558,7 +607,7 @@
   auto& d_v = outputs[2];
 
   sdpa_backward_cudnn(
-      q, k, v, scale_, o, stats, do_causal_, d_o, d_q, d_k, d_v, s);
+      q, k, v, scale_, o, stats, do_causal_, mask_arr, d_o, d_q, d_k, d_v, s);
 }
 
 } // namespace fast
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 731001a15..96b2a796c 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -569,6 +569,10 @@ bool ScaledDotProductAttention::use_fallback(
   return !(supports_sdpa_full || supports_sdpa_vector);
 }
 
+bool ScaledDotProductAttention::supports_bool_mask() {
+  return true;
+}
+
 void ScaledDotProductAttention::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp
index 25432165c..40695283c 100644
--- a/mlx/backend/no_gpu/primitives.cpp
+++ b/mlx/backend/no_gpu/primitives.cpp
@@ -36,6 +36,10 @@ bool fast::ScaledDotProductAttention::use_fallback(
   return true;
 }
 
+bool fast::ScaledDotProductAttention::supports_bool_mask() {
+  return false;
+}
+
 bool fast::ScaledDotProductAttentionVJP::use_fallback(
     const array& q,
     Stream s) {
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 378bb22ce..1ad42d0cf 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -800,6 +800,15 @@ array scaled_dot_product_attention(
           is_training,
           output_logsumexp,
           stream)) {
+    if (has_bool_mask && !ScaledDotProductAttention::supports_bool_mask()) {
+      // Convert bool mask to additive mask.
+      float inf = std::numeric_limits<float>::infinity();
+      array& mask = inputs[3];
+      mask = where(
+          mask,
+          full_like(mask, 0, final_type, s),
+          full_like(mask, -inf, final_type, s));
+    }
     Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     auto primitive = std::make_shared<ScaledDotProductAttention>(
         stream, fallback, scale, do_causal, has_sinks, output_logsumexp);
@@ -839,7 +848,7 @@ std::vector<array> ScaledDotProductAttention::vjp(
 
   std::vector<Shape> shapes;
   std::vector<Dtype> dtypes;
-  for (int i = 0; i < primals.size(); ++i) {
+  for (int i = 0; i < /* outputs size */ 3; ++i) {
     shapes.push_back(primals[i].shape());
     dtypes.push_back(primals[i].dtype());
   }
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index 6d8208e1d..443483087 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -228,6 +228,7 @@ class ScaledDotProductAttention : public Custom {
       bool is_training,
       bool output_logsumexp,
       Stream s);
+  static bool supports_bool_mask();
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index 4d2abaa33..d7ad6070a 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -739,37 +739,69 @@ class TestSDPA(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(out, expected, atol=1e-5))
 
     def test_sdpa_grad(self):
-        B, N_kv, T, D = (2, 8, 128, 64)
-        scale = D**-0.5
-
-        f1 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
-        f2 = lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
-
-        f3 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale).sum()
-        f4 = lambda q, k, v: (
-            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
-        ).sum()
-
         # High tolerance due to cuDNN SDPA kernel requiring tf32.
         tolerance = {"rtol": 1e-2, "atol": 1e-2}
 
+        def test_vjp(slow, fast, primals):
+            cotan = mx.ones_like(primals[0])
+            o1, vjp1 = mx.vjp(slow, primals, [cotan])
+            o2, vjp2 = mx.vjp(fast, primals, [cotan])
+
+            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
+            for i in range(3):
+                self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
+
+        def test_grad(slow, fast, args):
+            g1 = mx.grad(slow)(*args)
+            g2 = mx.grad(fast)(*args)
+
+            self.assertTrue(mx.allclose(g1, g2, **tolerance))
+
+        sdpa_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
+            q, k, v, scale=scale, mask=mask
+        )
+        sdpa_mask_fast = lambda q, k, v, mask: mx.fast.scaled_dot_product_attention(
+            q, k, v, scale=scale, mask=mask
+        )
+
+        loss_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
+            q, k, v, scale=scale, mask=mask
+        ).sum()
+        loss_mask_fast = lambda q, k, v, mask: (
+            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
+        ).sum()
+
+        B, N_kv, T, D = (2, 8, 128, 64)
+        scale = D**-0.5
+
         for N_q in (8, 32):
             q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)
             k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
             v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
 
-            cotan = mx.ones_like(q)
-            o1, vjp1 = mx.vjp(f1, [q, k, v], [cotan])
-            o2, vjp2 = mx.vjp(f2, [q, k, v], [cotan])
+            mask_additive = mx.random.normal((B, N_q, T, T), dtype=mx.float16)
+            mask_bool = mx.random.uniform(0, 1, (B, N_q, T, T), dtype=mx.float16) < 0.5
 
-            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
-            for i in range(3):
-                self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
+            for mask in (mask_additive, mask_bool):
+                test_vjp(sdpa_mask_slow, sdpa_mask_fast, [q, k, v, mask])
+                test_grad(loss_mask_slow, loss_mask_fast, [q, k, v, mask])
 
-            g1 = mx.grad(f3)(q, k, v)
-            g2 = mx.grad(f4)(q, k, v)
+            for mask in (None, "causal"):
+                sdpa_slow = lambda q, k, v: mlx_ref_attn(
+                    q, k, v, scale=scale, mask=mask
+                )
+                sdpa_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                    q, k, v, scale=scale, mask=mask
+                )
+                test_vjp(sdpa_slow, sdpa_fast, [q, k, v])
 
-            self.assertTrue(mx.allclose(g1, g2, **tolerance))
+                loss_slow = lambda q, k, v: mlx_ref_attn(
+                    q, k, v, scale=scale, mask=mask
+                ).sum()
+                loss_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                    q, k, v, scale=scale, mask=mask
+                ).sum()
+                test_grad(loss_slow, loss_fast, [q, k, v])
 
 
 if __name__ == "__main__":