no sdpa in grad (#2054)

This commit is contained in:
Awni Hannun 2025-04-08 19:13:54 -07:00 committed by GitHub
parent 00794c42bc
commit e5d35aa187
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 22 additions and 8 deletions

View File

@ -9,6 +9,7 @@
#include "mlx/fast_primitives.h" #include "mlx/fast_primitives.h"
#include "mlx/ops.h" #include "mlx/ops.h"
#include "mlx/transforms.h" #include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
namespace mlx::core::fast { namespace mlx::core::fast {
@ -772,7 +773,7 @@ array scaled_dot_product_attention(
mask_shape.back() = keys.shape(-2); mask_shape.back() = keys.shape(-2);
inputs.push_back(broadcast_to(mask_arr, mask_shape, stream)); inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
} }
if (implementation_supports_use_case) { if (!detail::in_grad_tracing() && implementation_supports_use_case) {
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}; auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
return array( return array(
std::move(out_shape), std::move(out_shape),

View File

@ -42,7 +42,8 @@ class Synchronizer : public Primitive {
// are currently under a function transformation and the retain_graph() // are currently under a function transformation and the retain_graph()
// function which returns true if we are forced to retain the graph during // function which returns true if we are forced to retain the graph during
// evaluation. // evaluation.
std::vector<char> detail::InTracing::trace_stack{}; std::vector<std::pair<char, char>> detail::InTracing::trace_stack{};
int detail::InTracing::grad_counter{0};
int detail::RetainGraph::tracing_counter{0}; int detail::RetainGraph::tracing_counter{0};
array eval_impl(std::vector<array> outputs, bool async) { array eval_impl(std::vector<array> outputs, bool async) {
@ -307,7 +308,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
const std::vector<array>& cotans, const std::vector<array>& cotans,
const std::vector<int>& argnums) { const std::vector<int>& argnums) {
// Set the global tracing flag. // Set the global tracing flag.
detail::InTracing in_tracing; detail::InTracing in_tracing{false, true};
// Make tracers from given primals // Make tracers from given primals
std::vector<array> primals_; std::vector<array> primals_;
@ -505,7 +506,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& tangents) { const std::vector<array>& tangents) {
// Set the global tracing flag. // Set the global tracing flag.
detail::InTracing in_tracing; detail::InTracing in_tracing{false, true};
if (primals.size() != tangents.size()) { if (primals.size() != tangents.size()) {
throw std::invalid_argument( throw std::invalid_argument(

View File

@ -20,10 +20,12 @@ std::vector<array> vmap_replace(
// of the codebase that we are during tracing so evals should not throw away // of the codebase that we are during tracing so evals should not throw away
// the graph. // the graph.
struct InTracing { struct InTracing {
explicit InTracing(bool dynamic = false) { explicit InTracing(bool dynamic = false, bool grad = false) {
trace_stack.push_back(dynamic); grad_counter += grad;
trace_stack.push_back({dynamic, grad});
} }
~InTracing() { ~InTracing() {
grad_counter -= trace_stack.back().second;
trace_stack.pop_back(); trace_stack.pop_back();
} }
@ -32,11 +34,16 @@ struct InTracing {
} }
static bool in_dynamic_tracing() { static bool in_dynamic_tracing() {
// compile is always and only the outer-most transform // compile is always and only the outer-most transform
return in_tracing() && trace_stack.front(); return in_tracing() && trace_stack.front().first;
}
static bool in_grad_tracing() {
return grad_counter > 0;
} }
private: private:
static std::vector<char> trace_stack; static int grad_counter;
static std::vector<std::pair<char, char>> trace_stack;
}; };
struct RetainGraph { struct RetainGraph {
@ -67,6 +74,11 @@ inline bool in_dynamic_tracing() {
return detail::InTracing::in_dynamic_tracing(); return detail::InTracing::in_dynamic_tracing();
} }
/**
 * Return true if we are in a gradient trace (vjp, jvp, etc).
 *
 * Thin wrapper over detail::InTracing::in_grad_tracing(), which reports
 * whether any live InTracing scope was constructed with grad = true.
 */
inline bool in_grad_tracing() {
return detail::InTracing::in_grad_tracing();
}
inline bool retain_graph() { inline bool retain_graph() {
return detail::RetainGraph::retain_graph(); return detail::RetainGraph::retain_graph();
} }