From e5d35aa1878096d5f84d068e1934a3fc93ae8955 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 8 Apr 2025 19:13:54 -0700
Subject: [PATCH] no sdpa in grad (#2054)

---
 mlx/fast.cpp          |  3 ++-
 mlx/transforms.cpp    |  7 ++++---
 mlx/transforms_impl.h | 20 ++++++++++++++++----
 3 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 994d50f10..4799fedc3 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -9,6 +9,7 @@
 #include "mlx/fast_primitives.h"
 #include "mlx/ops.h"
 #include "mlx/transforms.h"
+#include "mlx/transforms_impl.h"
 
 namespace mlx::core::fast {
 
@@ -772,7 +773,7 @@ array scaled_dot_product_attention(
     mask_shape.back() = keys.shape(-2);
     inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
   }
-  if (implementation_supports_use_case) {
+  if (!detail::in_grad_tracing() && implementation_supports_use_case) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index 54f3b302b..b305257f0 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -42,7 +42,8 @@ class Synchronizer : public Primitive {
 // are currently under a function transformation and the retain_graph()
 // function which returns true if we are forced to retain the graph during
 // evaluation.
-std::vector<bool> detail::InTracing::trace_stack{};
+std::vector<std::pair<bool, bool>> detail::InTracing::trace_stack{};
+int detail::InTracing::grad_counter{0};
 int detail::RetainGraph::tracing_counter{0};
 
 array eval_impl(std::vector<array> outputs, bool async) {
@@ -307,7 +308,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
     const std::vector<array>& cotans,
     const std::vector<int>& argnums) {
   // Set the global tracing flag.
-  detail::InTracing in_tracing;
+  detail::InTracing in_tracing{false, true};
 
   // Make tracers from given primals
   std::vector<array> primals_;
@@ -505,7 +506,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents) {
   // Set the global tracing flag.
-  detail::InTracing in_tracing;
+  detail::InTracing in_tracing{false, true};
 
   if (primals.size() != tangents.size()) {
     throw std::invalid_argument(
diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h
index 3aa84bde4..7f62c406b 100644
--- a/mlx/transforms_impl.h
+++ b/mlx/transforms_impl.h
@@ -20,10 +20,12 @@ std::vector<array> vmap_replace(
 // of the codebase that we are during tracing so evals should not throw away
 // the graph.
 struct InTracing {
-  explicit InTracing(bool dynamic = false) {
-    trace_stack.push_back(dynamic);
+  explicit InTracing(bool dynamic = false, bool grad = false) {
+    grad_counter += grad;
+    trace_stack.push_back({dynamic, grad});
   }
   ~InTracing() {
+    grad_counter -= trace_stack.back().second;
     trace_stack.pop_back();
   }
 
@@ -32,11 +34,16 @@ struct InTracing {
   }
   static bool in_dynamic_tracing() {
     // compile is always and only the outer-most transform
-    return in_tracing() && trace_stack.front();
+    return in_tracing() && trace_stack.front().first;
+  }
+
+  static bool in_grad_tracing() {
+    return grad_counter > 0;
   }
 
  private:
-  static std::vector<bool> trace_stack;
+  static int grad_counter;
+  static std::vector<std::pair<bool, bool>> trace_stack;
 };
 
 struct RetainGraph {
@@ -67,6 +74,11 @@ inline bool in_dynamic_tracing() {
   return detail::InTracing::in_dynamic_tracing();
 }
 
+/** Return true if we are in a gradient trace (vjp, jvp, etc). */
+inline bool in_grad_tracing() {
+  return detail::InTracing::in_grad_tracing();
+}
+
 inline bool retain_graph() {
   return detail::RetainGraph::retain_graph();
 }