From 6f35017d1baf7764a01893dc54e818dd5cc94aa5 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 19 Nov 2025 08:13:50 +0900 Subject: [PATCH] [CUDA] cuDNN backward attention (#2762) --- mlx/backend/cuda/CMakeLists.txt | 2 +- .../cuda/scaled_dot_product_attention.cpp | 318 +++++++++++++++--- .../cuda/scaled_dot_product_attention.cu | 7 +- .../metal/scaled_dot_product_attention.cpp | 24 +- mlx/backend/no_gpu/primitives.cpp | 11 +- mlx/fast.cpp | 83 ++++- mlx/fast_primitives.h | 76 ++++- python/tests/test_fast_sdpa.py | 33 ++ 8 files changed, 472 insertions(+), 82 deletions(-) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 6b440ebc7..e11a18f95 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -168,7 +168,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver) FetchContent_Declare( cudnn GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git - GIT_TAG v1.14.0 + GIT_TAG v1.16.0 GIT_SHALLOW TRUE EXCLUDE_FROM_ALL) set(CUDNN_FRONTEND_SKIP_JSON_LIB ON) diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cpp b/mlx/backend/cuda/scaled_dot_product_attention.cpp index bf40953c0..b588603eb 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cpp +++ b/mlx/backend/cuda/scaled_dot_product_attention.cpp @@ -5,7 +5,6 @@ #include "mlx/backend/cuda/lru_cache.h" #include "mlx/backend/gpu/copy.h" #include "mlx/fast_primitives.h" -#include "mlx/transforms_impl.h" #include @@ -26,6 +25,11 @@ namespace { std::vector normalized_strides(const array& x) { std::vector strides(x.strides().begin(), x.strides().end()); + if (std::all_of( + strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) { + strides.back() = 1; + return strides; + } if (!x.flags().row_contiguous || x.ndim() < 2) { return strides; } @@ -71,11 +75,41 @@ struct SDPACacheKey { std::array k_strides; std::array v_strides; bool do_causal; + bool output_logsumexp; }; +inline BytesKey build_sdpa_cache_key( + cu::CommandEncoder& encoder, + const array& q, + const array& k, + const array& v, + bool do_causal, + bool output_logsumexp = true) { + BytesKey cache_key; + cache_key.pod = { + encoder.device().cuda_device(), + dtype_to_cudnn_type(q.dtype()), + vector_key(q.shape()), + vector_key(k.shape()), + vector_key(v.shape()), + vector_key(q.strides()), + vector_key(k.strides()), + vector_key(v.strides()), + do_causal, + output_logsumexp, + }; + return cache_key; +} + auto& sdpa_cache() { static LRUBytesKeyCache cache( - "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 128); + "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 16); + return cache; +} + +auto& sdpa_backward_cache() { + static LRUBytesKeyCache cache( + "MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 16); return cache; } @@ -85,6 +119,12 @@ enum UIDS { V, SCALE, O, + STATS, + // Backward graph: + D_Q, + D_K, + D_V, + D_O, }; fe::graph::Graph build_sdpa_graph( @@ -93,7 +133,9 @@ fe::graph::Graph build_sdpa_graph( const array& k, const array& v, bool do_causal, - const array& o) { + bool output_logsumexp, + const array& o, + const array& stats) { auto dtype = fe::DataType_t::HALF; if (q.dtype() == bfloat16) { dtype = fe::DataType_t::BFLOAT16; @@ -119,15 +161,19 @@ fe::graph::Graph build_sdpa_graph( .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - auto sdpa_options = fe::graph::SDPA_attributes() - .set_name("sdpa_cudnn") - .set_attn_scale(scale) - .set_causal_mask(do_causal) - .set_generate_stats(false); + auto options = fe::graph::SDPA_attributes() + 
.set_name("sdpa_cudnn") + .set_attn_scale(scale) + .set_causal_mask(do_causal) + .set_generate_stats(output_logsumexp); - auto [o_, _] = graph.sdpa(q_, k_, v_, sdpa_options); + auto [o_, stats_] = graph.sdpa(q_, k_, v_, options); o_->set_output(true); set_tensor_attrs(o_, O, o); + if (output_logsumexp) { + stats_->set_output(true).set_data_type(fe::DataType_t::FLOAT); + set_tensor_attrs(stats_, STATS, stats); + } CHECK_CUDNN_FE_ERROR(graph.validate()); CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle)); @@ -140,6 +186,100 @@ fe::graph::Graph build_sdpa_graph( return graph; } +fe::graph::Graph build_sdpa_backward_graph( + cudnnHandle_t handle, + const array& q, + const array& k, + const array& v, + bool do_causal, + const array& o, + const array& d_o, + const array& stats, + array& d_q, + array& d_k, + array& d_v) { + auto dtype = fe::DataType_t::HALF; + if (q.dtype() == bfloat16) { + dtype = fe::DataType_t::BFLOAT16; + } + + fe::graph::Graph graph; + graph.set_io_data_type(dtype) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q")); + auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K")); + auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V")); + auto o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("O")); + auto d_o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("D_O")); + auto stats_ = graph.tensor(fe::graph::Tensor_attributes().set_name("STATS")); + set_tensor_attrs(q_, Q, q); + set_tensor_attrs(k_, K, k); + set_tensor_attrs(v_, V, v); + set_tensor_attrs(o_, O, o); + set_tensor_attrs(d_o_, D_O, d_o); + set_tensor_attrs(stats_, STATS, stats); + stats_->set_data_type(fe::DataType_t::FLOAT); + + auto scale = graph.tensor(fe::graph::Tensor_attributes() + .set_name("Scale") + .set_uid(SCALE) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); + + auto options = fe::graph::SDPA_backward_attributes() + .set_name("sdpa_backward_cudnn") + .set_attn_scale(scale) + .set_causal_mask(do_causal); + + auto [d_q_, d_k_, d_v_] = + graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options); + d_q_->set_output(true); + d_k_->set_output(true); + d_v_->set_output(true); + set_tensor_attrs(d_q_, D_Q, d_q); + set_tensor_attrs(d_k_, D_K, d_k); + set_tensor_attrs(d_v_, D_V, d_v); + + CHECK_CUDNN_FE_ERROR(graph.validate()); + CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle)); + CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A})); + graph.select_behavior_notes( + {fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API}); + CHECK_CUDNN_FE_ERROR(graph.check_support(handle)); + CHECK_CUDNN_FE_ERROR(graph.build_plans(handle)); + + return graph; +} + +void execute_graph( + cu::CommandEncoder& encoder, + cudnnHandle_t handle, + fe::graph::Graph& graph, + std::unordered_map& variant_pack) { + int64_t workspace_size = 0; + CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size)); + void* workspace_ptr = nullptr; + if (workspace_size > 0) { + array workspace( + cu::malloc_async(workspace_size, encoder), + {static_cast(workspace_size)}, + uint8); + encoder.add_temporary(workspace); + workspace_ptr = gpu_ptr(workspace); + } + + cudnnSetStream(handle, encoder.stream()); + + CudaGraph cuda_graph(encoder.device()); + CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph( + handle, variant_pack, workspace_ptr, cuda_graph)); + 
encoder.add_graph_node(cuda_graph); +} + } // namespace bool supports_sdpa_cudnn( @@ -170,7 +310,7 @@ bool supports_sdpa_cudnn( } } - // Only use cuDNN for prefilling. + // Only use cuDNN for prefilling and training. if (q.shape(2) != k.shape(2)) { return false; } @@ -191,9 +331,13 @@ void sdpa_cudnn( const array& v, float scale, array& o, + array& stats, bool do_causal, + bool output_logsumexp, Stream s) { auto& encoder = cu::get_command_encoder(s); + auto handle = encoder.device().cudnn_handle(); + // TODO: Handle donation. // TODO: Make O use same memory layout with Q. o.set_data(cu::malloc_async(o.nbytes(), encoder)); @@ -203,28 +347,19 @@ void sdpa_cudnn( encoder.set_input_array(v); encoder.set_output_array(o); - auto handle = encoder.device().cudnn_handle(); - cudnnSetStream(handle, encoder.stream()); + if (output_logsumexp) { + stats.set_data(cu::malloc_async(stats.nbytes(), encoder)); + encoder.set_output_array(stats); + } // Search cache. - BytesKey cache_key; - cache_key.pod = { - encoder.device().cuda_device(), - dtype_to_cudnn_type(q.dtype()), - vector_key(q.shape()), - vector_key(k.shape()), - vector_key(v.shape()), - vector_key(q.strides()), - vector_key(k.strides()), - vector_key(v.strides()), - do_causal, - }; + auto cache_key = + build_sdpa_cache_key(encoder, q, k, v, do_causal, output_logsumexp); auto it = sdpa_cache().find(cache_key); if (it == sdpa_cache().end()) { - it = - sdpa_cache() - .emplace(cache_key, build_sdpa_graph(handle, q, k, v, do_causal, o)) - .first; + auto graph = build_sdpa_graph( + handle, q, k, v, do_causal, output_logsumexp, o, stats); + it = sdpa_cache().emplace(cache_key, std::move(graph)).first; } auto& graph = it->second; @@ -234,23 +369,67 @@ void sdpa_cudnn( {V, const_cast(gpu_ptr(v))}, {SCALE, &scale}, {O, gpu_ptr(o)}}; - - int64_t workspace_size = 0; - CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size)); - void* workspace_ptr = nullptr; - if (workspace_size > 0) { - array workspace( - cu::malloc_async(workspace_size, encoder), - {static_cast(workspace_size)}, - uint8); - encoder.add_temporary(workspace); - workspace_ptr = gpu_ptr(workspace); + if (output_logsumexp) { + variant_pack[STATS] = gpu_ptr(stats); } - CudaGraph cuda_graph(encoder.device()); - CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph( - handle, variant_pack, workspace_ptr, cuda_graph)); - encoder.add_graph_node(cuda_graph); + execute_graph(encoder, handle, graph, variant_pack); +} + +void sdpa_backward_cudnn( + const array& q, + const array& k, + const array& v, + float scale, + const array& o, + const array& stats, + bool do_causal, + const array& d_o, + array& d_q, + array& d_k, + array& d_v, + Stream s) { + auto& encoder = cu::get_command_encoder(s); + auto handle = encoder.device().cudnn_handle(); + + // TODO: Handle donation. + d_q.set_data(cu::malloc_async(d_q.nbytes(), encoder)); + d_k.set_data(cu::malloc_async(d_k.nbytes(), encoder)); + d_v.set_data(cu::malloc_async(d_v.nbytes(), encoder)); + + encoder.set_input_array(q); + encoder.set_input_array(k); + encoder.set_input_array(v); + encoder.set_input_array(o); + encoder.set_input_array(stats); + encoder.set_input_array(d_o); + encoder.set_output_array(d_q); + encoder.set_output_array(d_k); + encoder.set_output_array(d_v); + + // Search cache. 
+ auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal); + auto it = sdpa_backward_cache().find(cache_key); + if (it == sdpa_backward_cache().end()) { + auto graph = build_sdpa_backward_graph( + handle, q, k, v, do_causal, o, d_o, stats, d_q, d_k, d_v); + it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first; + } + auto& graph = it->second; + + std::unordered_map variant_pack{ + {Q, const_cast(gpu_ptr(q))}, + {K, const_cast(gpu_ptr(k))}, + {V, const_cast(gpu_ptr(v))}, + {SCALE, &scale}, + {O, const_cast(gpu_ptr(o))}, + {STATS, const_cast(gpu_ptr(stats))}, + {D_O, const_cast(gpu_ptr(d_o))}, + {D_Q, gpu_ptr(d_q)}, + {D_K, gpu_ptr(d_k)}, + {D_V, gpu_ptr(d_v)}}; + + execute_graph(encoder, handle, graph, variant_pack); } // Defined in scaled_dot_product_attention.cu file. @@ -260,7 +439,8 @@ bool supports_sdpa_vector( const array& v, bool has_mask, bool has_arr_mask, - bool do_causal); + bool do_causal, + bool output_logsumexp); void sdpa_vector( const array& q, const array& k, @@ -280,21 +460,21 @@ bool ScaledDotProductAttention::use_fallback( bool has_mask, bool has_arr_mask, bool do_causal, + bool is_training, + bool output_logsumexp, Stream s) { - if (detail::in_grad_tracing()) { - return true; - } if (s.device == Device::cpu) { return true; } - return !supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal) && + return !supports_sdpa_vector( + q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) && !supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s); } void ScaledDotProductAttention::eval_gpu( const std::vector& inputs, - array& out) { + std::vector& outputs) { nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu"); auto& s = stream(); @@ -302,20 +482,56 @@ void ScaledDotProductAttention::eval_gpu( array q = prepare_sdpa_input(inputs[0], s); array k = prepare_sdpa_input(inputs[1], s); array v = prepare_sdpa_input(inputs[2], s); + auto& out = outputs[0]; + auto& stats = outputs[1]; bool has_mask = inputs.size() - has_sinks_ > 3; bool has_arr_mask = has_mask && !do_causal_; - if (supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal_)) { + if (supports_sdpa_vector( + q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) { if (has_sinks_) { sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s); } else { sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s); } } else { - sdpa_cudnn(q, k, v, scale_, out, do_causal_, s); + sdpa_cudnn(q, k, v, scale_, out, stats, do_causal_, output_logsumexp_, s); } } +bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) { + // The frontend adds a padding mask when sequence length is not a multiple of + // tile size. 
+ if (q.shape(2) % 128 != 0) { + return true; + } + return s.device == Device::cpu; +} + +void ScaledDotProductAttentionVJP::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + nvtx3::scoped_range r("ScaledDotProductAttentionVJP::eval_gpu"); + + auto& s = stream(); + + assert(inputs.size() == 6); + array q = prepare_sdpa_input(inputs[0], s); + array k = prepare_sdpa_input(inputs[1], s); + array v = prepare_sdpa_input(inputs[2], s); + array o = prepare_sdpa_input(inputs[3], s); + array stats = prepare_sdpa_input(inputs[4], s); + array d_o = prepare_sdpa_input(inputs[5], s); + + assert(outputs.size() == 3); + auto& d_q = outputs[0]; + auto& d_k = outputs[1]; + auto& d_v = outputs[2]; + + sdpa_backward_cudnn( + q, k, v, scale_, o, stats, do_causal_, d_o, d_q, d_k, d_v, s); +} + } // namespace fast } // namespace mlx::core diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu index 57108d443..8ea828c09 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cu +++ b/mlx/backend/cuda/scaled_dot_product_attention.cu @@ -664,7 +664,12 @@ bool supports_sdpa_vector( const array& v, bool has_mask, bool has_arr_mask, - bool do_causal) { + bool do_causal, + bool output_logsumexp) { + if (output_logsumexp) { + return false; + } + const int value_head_dim = v.shape(-1); const int query_head_dim = q.shape(-1); const int query_sequence_length = q.shape(2); diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 0aca3170e..d8adf8199 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -7,7 +7,6 @@ #include "mlx/backend/metal/kernels/steel/attn/params.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" -#include "mlx/transforms_impl.h" #include "mlx/utils.h" namespace mlx::core::fast { @@ -379,8 +378,15 @@ bool ScaledDotProductAttention::use_fallback( bool has_mask, bool has_arr_mask, bool do_causal, + bool is_training, + bool output_logsumexp, Stream s) { - if (detail::in_grad_tracing()) { + if (is_training) { + // It's faster for training on Metal to use the unfused SDPA for both + // forward and backward. 
+ return true; + } + if (output_logsumexp) { return true; } if (s.device == Device::cpu) { @@ -414,14 +420,14 @@ bool ScaledDotProductAttention::use_fallback( void ScaledDotProductAttention::eval_gpu( const std::vector& inputs, - array& out) { + std::vector& outputs) { auto& s = stream(); auto& d = metal::device(s.device); auto& q_pre = inputs[0]; auto& k_pre = inputs[1]; auto& v_pre = inputs[2]; - auto& o = out; + auto& o = outputs[0]; std::vector copies; @@ -553,4 +559,14 @@ void ScaledDotProductAttention::eval_gpu( d.add_temporaries(std::move(copies), s.index); } +bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) { + return true; +} + +void ScaledDotProductAttentionVJP::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + throw std::runtime_error("NYI"); +} + } // namespace mlx::core::fast diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 406a627b9..ba625947c 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -30,6 +30,14 @@ bool fast::ScaledDotProductAttention::use_fallback( bool has_mask, bool has_arr_mask, bool do_causal, + bool is_training, + bool output_logsumexp, + Stream s) { + return true; +} + +bool fast::ScaledDotProductAttentionVJP::use_fallback( + const array& q, Stream s) { return true; } @@ -153,7 +161,8 @@ NO_GPU_MULTI(LayerNormVJP) NO_GPU_USE_FALLBACK(RMSNorm) NO_GPU_MULTI(RMSNormVJP) NO_GPU_USE_FALLBACK(RoPE) -NO_GPU(ScaledDotProductAttention) +NO_GPU_MULTI(ScaledDotProductAttention) +NO_GPU_MULTI(ScaledDotProductAttentionVJP) NO_GPU_MULTI(ConvertFP8) NO_GPU_MULTI(Quantize) NO_GPU_MULTI(CustomKernel) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index c7c572a08..b51c747d8 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -6,6 +6,7 @@ #include "mlx/fast_primitives.h" #include "mlx/ops.h" #include "mlx/transforms.h" +#include "mlx/transforms_impl.h" namespace mlx::core::fast { @@ -784,22 +785,88 @@ array scaled_dot_product_attention( inputs.push_back(astype(*sinks, final_type, stream)); } + bool is_training = detail::in_grad_tracing(); + bool has_fast_vjp = !ScaledDotProductAttentionVJP::use_fallback(q, stream); + bool output_logsumexp = is_training && has_fast_vjp; if (!ScaledDotProductAttention::use_fallback( - q, k, v, has_mask, has_arr_mask, do_causal, stream)) { - auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}; - return array( - std::move(out_shape), - final_type, - std::make_shared( - stream, fallback, scale, do_causal, has_sinks), - std::move(inputs)); + q, + k, + v, + has_mask, + has_arr_mask, + do_causal, + is_training, + output_logsumexp, + stream)) { + Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}; + auto primitive = std::make_shared( + stream, fallback, scale, do_causal, has_sinks, output_logsumexp); + if (output_logsumexp) { + return array::make_arrays( + {std::move(out_shape), Shape{q.shape(0), q.shape(1), q.shape(2), 1}}, + {final_type, float32}, + primitive, + std::move(inputs))[0]; + } else { + return array( + std::move(out_shape), final_type, primitive, std::move(inputs)); + } } return fallback(std::move(inputs))[0]; } +std::vector ScaledDotProductAttention::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) { + assert(primals.size() >= 3); + assert(cotangents.size() == outputs.size()); + + auto s = stream(); + if (ScaledDotProductAttentionVJP::use_fallback(primals[0], s)) { + assert(outputs.size() == 1); + return 
Custom::vjp(primals, cotangents, argnums, outputs); + } + + auto fallback = [sdpa = fallback_, s](const std::vector& inputs) { + std::vector primals(inputs.begin(), std::prev(inputs.end())); + auto [_, vjps] = mlx::core::vjp(sdpa, primals, {inputs.back()}); + return vjps; + }; + + std::vector shapes; + std::vector dtypes; + for (int i = 0; i < primals.size(); ++i) { + shapes.push_back(primals[i].shape()); + dtypes.push_back(primals[i].dtype()); + } + auto primitive = std::make_shared( + s, fallback, scale_, do_causal_, has_sinks_); + std::vector inputs = primals; + inputs.push_back(outputs[0]); + inputs.push_back(outputs[1]); + inputs.push_back(cotangents[0]); + auto vjps = array::make_arrays(std::move(shapes), dtypes, primitive, inputs); + + std::vector returned_vjps; + for (int arg : argnums) { + returned_vjps.push_back(std::move(vjps[arg])); + } + return returned_vjps; +} + bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { const ScaledDotProductAttention& a_other = static_cast(other); + return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ && + has_sinks_ == a_other.has_sinks_ && + output_logsumexp_ == a_other.output_logsumexp_; +} + +bool ScaledDotProductAttentionVJP::is_equivalent(const Primitive& other) const { + const ScaledDotProductAttentionVJP& a_other = + static_cast(other); return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ && has_sinks_ == a_other.has_sinks_; } diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 584860e52..6d8208e1d 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -15,7 +15,7 @@ class Custom : public Primitive { explicit Custom( Stream stream, std::function(std::vector)> fallback) - : Primitive(stream), fallback_(fallback) {} + : Primitive(stream), fallback_(std::move(fallback)) {} virtual std::pair, std::vector> vmap( const std::vector& inputs, @@ -32,7 +32,7 @@ class Custom : public Primitive { const std::vector& argnums, const std::vector& outputs) override; - private: + protected: std::function(std::vector)> fallback_; }; @@ -42,7 +42,7 @@ class RMSNorm : public Custom { Stream stream, std::function(std::vector)> fallback, float eps) - : Custom(stream, fallback), eps_(eps) {} + : Custom(stream, std::move(fallback)), eps_(eps) {} static bool use_fallback(Stream stream); @@ -77,7 +77,7 @@ class RMSNormVJP : public Custom { Stream stream, std::function(std::vector)> fallback, float eps) - : Custom(stream, fallback), eps_(eps) {} + : Custom(stream, std::move(fallback)), eps_(eps) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -102,7 +102,7 @@ class LayerNorm : public Custom { Stream stream, std::function(std::vector)> fallback, float eps) - : Custom(stream, fallback), eps_(eps) {} + : Custom(stream, std::move(fallback)), eps_(eps) {} static bool use_fallback(Stream s); @@ -136,7 +136,7 @@ class LayerNormVJP : public Custom { Stream stream, std::function(std::vector)> fallback, float eps) - : Custom(stream, fallback), eps_(eps) {} + : Custom(stream, std::move(fallback)), eps_(eps) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -165,7 +165,7 @@ class RoPE : public Custom { float base, float scale, bool forward) - : Custom(stream, fallback), + : Custom(stream, std::move(fallback)), dims_(dims), traditional_(traditional), base_(base), @@ -205,16 +205,18 @@ class RoPE : public Custom { class ScaledDotProductAttention : public Custom { public: - explicit ScaledDotProductAttention( + ScaledDotProductAttention( 
Stream stream, std::function(std::vector)> fallback, float scale, bool do_causal, - bool has_sinks) - : Custom(stream, fallback), + bool has_sinks, + bool output_logsumexp) + : Custom(stream, std::move(fallback)), scale_(scale), do_causal_(do_causal), - has_sinks_(has_sinks) {} + has_sinks_(has_sinks), + output_logsumexp_(output_logsumexp) {} static bool use_fallback( const array& q, @@ -223,6 +225,8 @@ class ScaledDotProductAttention : public Custom { bool has_mask, bool has_arr_mask, bool do_causal, + bool is_training, + bool output_logsumexp, Stream s); void eval_cpu(const std::vector& inputs, std::vector& outputs) @@ -231,15 +235,55 @@ class ScaledDotProductAttention : public Custom { } void eval_gpu(const std::vector& inputs, std::vector& outputs) - override { - eval_gpu(inputs, outputs[0]); - } + override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; - void eval_gpu(const std::vector& inputs, array& out); bool is_equivalent(const Primitive& other) const override; DEFINE_NAME(ScaledDotProductAttention); DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return std::make_tuple( + nullptr, scale_, do_causal_, has_sinks_, output_logsumexp_); + } + + private: + float scale_; + bool do_causal_; + bool has_sinks_; + bool output_logsumexp_; +}; + +class ScaledDotProductAttentionVJP : public Custom { + public: + ScaledDotProductAttentionVJP( + Stream stream, + std::function(std::vector)> fallback, + float scale, + bool do_causal, + bool has_sinks) + : Custom(stream, std::move(fallback)), + scale_(scale), + do_causal_(do_causal), + has_sinks_(has_sinks) {} + + static bool use_fallback(const array& q, Stream s); + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(ScaledDotProductAttentionVJP); + bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple(nullptr, scale_, do_causal_, has_sinks_); } @@ -288,7 +332,7 @@ class Quantize : public Custom { int bits, QuantizationMode mode, bool dequantize) - : Custom(stream, fallback), + : Custom(stream, std::move(fallback)), group_size_(group_size), bits_(bits), mode_(mode), diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index e5c341241..4d2abaa33 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -738,6 +738,39 @@ class TestSDPA(mlx_tests.MLXTestCase): ) self.assertTrue(mx.allclose(out, expected, atol=1e-5)) + def test_sdpa_grad(self): + B, N_kv, T, D = (2, 8, 128, 64) + scale = D**-0.5 + + f1 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale) + f2 = lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) + + f3 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale).sum() + f4 = lambda q, k, v: ( + mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) + ).sum() + + # High tolerance due to cuDNN SDPA kernel requiring tf32. 
+ tolerance = {"rtol": 1e-2, "atol": 1e-2} + + for N_q in (8, 32): + q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16) + k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16) + v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16) + + cotan = mx.ones_like(q) + o1, vjp1 = mx.vjp(f1, [q, k, v], [cotan]) + o2, vjp2 = mx.vjp(f2, [q, k, v], [cotan]) + + self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance)) + for i in range(3): + self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance)) + + g1 = mx.grad(f3)(q, k, v) + g2 = mx.grad(f4)(q, k, v) + + self.assertTrue(mx.allclose(g1, g2, **tolerance)) + if __name__ == "__main__": mlx_tests.MLXTestRunner(failfast=True)