mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
no sdpa in grad (#2054)
This commit is contained in:
@@ -20,10 +20,12 @@ std::vector<array> vmap_replace(
|
||||
// of the codebase that we are during tracing so evals should not throw away
|
||||
// the graph.
|
||||
struct InTracing {
|
||||
explicit InTracing(bool dynamic = false) {
|
||||
trace_stack.push_back(dynamic);
|
||||
explicit InTracing(bool dynamic = false, bool grad = false) {
|
||||
grad_counter += grad;
|
||||
trace_stack.push_back({dynamic, grad});
|
||||
}
|
||||
~InTracing() {
|
||||
grad_counter -= trace_stack.back().second;
|
||||
trace_stack.pop_back();
|
||||
}
|
||||
|
||||
@@ -32,11 +34,16 @@ struct InTracing {
|
||||
}
|
||||
static bool in_dynamic_tracing() {
|
||||
// compile is always and only the outer-most transform
|
||||
return in_tracing() && trace_stack.front();
|
||||
return in_tracing() && trace_stack.front().first;
|
||||
}
|
||||
|
||||
static bool in_grad_tracing() {
|
||||
return grad_counter > 0;
|
||||
}
|
||||
|
||||
private:
|
||||
static std::vector<char> trace_stack;
|
||||
static int grad_counter;
|
||||
static std::vector<std::pair<char, char>> trace_stack;
|
||||
};
|
||||
|
||||
struct RetainGraph {
|
||||
@@ -67,6 +74,11 @@ inline bool in_dynamic_tracing() {
|
||||
return detail::InTracing::in_dynamic_tracing();
|
||||
}
|
||||
|
||||
/**
 * Return true if we are in a gradient trace (vjp, jvp, etc).
 *
 * Forwards to detail::InTracing::in_grad_tracing(), which reports whether
 * grad_counter > 0, i.e. whether at least one InTracing scope that was
 * constructed with grad = true is currently alive.
 */
|
||||
inline bool in_grad_tracing() {
|
||||
return detail::InTracing::in_grad_tracing();
|
||||
}
|
||||
|
||||
/**
 * Return the value reported by detail::RetainGraph::retain_graph().
 *
 * NOTE(review): the body of detail::RetainGraph is not visible in this
 * excerpt; presumably a true result means evals should keep the graph --
 * confirm against the struct definition.
 */
inline bool retain_graph() {
|
||||
return detail::RetainGraph::retain_graph();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user