mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
no sdpa in grad (#2054)
This commit is contained in:
parent
00794c42bc
commit
e5d35aa187
@ -9,6 +9,7 @@
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
@ -772,7 +773,7 @@ array scaled_dot_product_attention(
|
||||
mask_shape.back() = keys.shape(-2);
|
||||
inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
|
||||
}
|
||||
if (implementation_supports_use_case) {
|
||||
if (!detail::in_grad_tracing() && implementation_supports_use_case) {
|
||||
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
|
||||
return array(
|
||||
std::move(out_shape),
|
||||
|
@ -42,7 +42,8 @@ class Synchronizer : public Primitive {
|
||||
// are currently under a function transformation and the retain_graph()
|
||||
// function which returns true if we are forced to retain the graph during
|
||||
// evaluation.
|
||||
std::vector<char> detail::InTracing::trace_stack{};
|
||||
std::vector<std::pair<char, char>> detail::InTracing::trace_stack{};
|
||||
int detail::InTracing::grad_counter{0};
|
||||
int detail::RetainGraph::tracing_counter{0};
|
||||
|
||||
array eval_impl(std::vector<array> outputs, bool async) {
|
||||
@ -307,7 +308,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||
const std::vector<array>& cotans,
|
||||
const std::vector<int>& argnums) {
|
||||
// Set the global tracing flag.
|
||||
detail::InTracing in_tracing;
|
||||
detail::InTracing in_tracing{false, true};
|
||||
|
||||
// Make tracers from given primals
|
||||
std::vector<array> primals_;
|
||||
@ -505,7 +506,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents) {
|
||||
// Set the global tracing flag.
|
||||
detail::InTracing in_tracing;
|
||||
detail::InTracing in_tracing{false, true};
|
||||
|
||||
if (primals.size() != tangents.size()) {
|
||||
throw std::invalid_argument(
|
||||
|
@ -20,10 +20,12 @@ std::vector<array> vmap_replace(
|
||||
// of the codebase that we are during tracing so evals should not throw away
|
||||
// the graph.
|
||||
struct InTracing {
|
||||
explicit InTracing(bool dynamic = false) {
|
||||
trace_stack.push_back(dynamic);
|
||||
explicit InTracing(bool dynamic = false, bool grad = false) {
|
||||
grad_counter += grad;
|
||||
trace_stack.push_back({dynamic, grad});
|
||||
}
|
||||
~InTracing() {
|
||||
grad_counter -= trace_stack.back().second;
|
||||
trace_stack.pop_back();
|
||||
}
|
||||
|
||||
@ -32,11 +34,16 @@ struct InTracing {
|
||||
}
|
||||
static bool in_dynamic_tracing() {
|
||||
// compile is always and only the outer-most transform
|
||||
return in_tracing() && trace_stack.front();
|
||||
return in_tracing() && trace_stack.front().first;
|
||||
}
|
||||
|
||||
static bool in_grad_tracing() {
|
||||
return grad_counter > 0;
|
||||
}
|
||||
|
||||
private:
|
||||
static std::vector<char> trace_stack;
|
||||
static int grad_counter;
|
||||
static std::vector<std::pair<char, char>> trace_stack;
|
||||
};
|
||||
|
||||
struct RetainGraph {
|
||||
@ -67,6 +74,11 @@ inline bool in_dynamic_tracing() {
|
||||
return detail::InTracing::in_dynamic_tracing();
|
||||
}
|
||||
|
||||
/** Return true if we are in a gradient trace (vjp, jvp, etc). */
|
||||
inline bool in_grad_tracing() {
|
||||
return detail::InTracing::in_grad_tracing();
|
||||
}
|
||||
|
||||
inline bool retain_graph() {
|
||||
return detail::RetainGraph::retain_graph();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user