mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
no sdpa in grad (#2054)
This commit is contained in:
parent
00794c42bc
commit
e5d35aa187
@ -9,6 +9,7 @@
|
|||||||
#include "mlx/fast_primitives.h"
|
#include "mlx/fast_primitives.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/transforms.h"
|
#include "mlx/transforms.h"
|
||||||
|
#include "mlx/transforms_impl.h"
|
||||||
|
|
||||||
namespace mlx::core::fast {
|
namespace mlx::core::fast {
|
||||||
|
|
||||||
@ -772,7 +773,7 @@ array scaled_dot_product_attention(
|
|||||||
mask_shape.back() = keys.shape(-2);
|
mask_shape.back() = keys.shape(-2);
|
||||||
inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
|
inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
|
||||||
}
|
}
|
||||||
if (implementation_supports_use_case) {
|
if (!detail::in_grad_tracing() && implementation_supports_use_case) {
|
||||||
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
|
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
|
||||||
return array(
|
return array(
|
||||||
std::move(out_shape),
|
std::move(out_shape),
|
||||||
|
@ -42,7 +42,8 @@ class Synchronizer : public Primitive {
|
|||||||
// are currently under a function transformation and the retain_graph()
|
// are currently under a function transformation and the retain_graph()
|
||||||
// function which returns true if we are forced to retain the graph during
|
// function which returns true if we are forced to retain the graph during
|
||||||
// evaluation.
|
// evaluation.
|
||||||
std::vector<char> detail::InTracing::trace_stack{};
|
std::vector<std::pair<char, char>> detail::InTracing::trace_stack{};
|
||||||
|
int detail::InTracing::grad_counter{0};
|
||||||
int detail::RetainGraph::tracing_counter{0};
|
int detail::RetainGraph::tracing_counter{0};
|
||||||
|
|
||||||
array eval_impl(std::vector<array> outputs, bool async) {
|
array eval_impl(std::vector<array> outputs, bool async) {
|
||||||
@ -307,7 +308,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
|
|||||||
const std::vector<array>& cotans,
|
const std::vector<array>& cotans,
|
||||||
const std::vector<int>& argnums) {
|
const std::vector<int>& argnums) {
|
||||||
// Set the global tracing flag.
|
// Set the global tracing flag.
|
||||||
detail::InTracing in_tracing;
|
detail::InTracing in_tracing{false, true};
|
||||||
|
|
||||||
// Make tracers from given primals
|
// Make tracers from given primals
|
||||||
std::vector<array> primals_;
|
std::vector<array> primals_;
|
||||||
@ -505,7 +506,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
|
|||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& tangents) {
|
const std::vector<array>& tangents) {
|
||||||
// Set the global tracing flag.
|
// Set the global tracing flag.
|
||||||
detail::InTracing in_tracing;
|
detail::InTracing in_tracing{false, true};
|
||||||
|
|
||||||
if (primals.size() != tangents.size()) {
|
if (primals.size() != tangents.size()) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
|
@ -20,10 +20,12 @@ std::vector<array> vmap_replace(
|
|||||||
// of the codebase that we are during tracing so evals should not throw away
|
// of the codebase that we are during tracing so evals should not throw away
|
||||||
// the graph.
|
// the graph.
|
||||||
struct InTracing {
|
struct InTracing {
|
||||||
explicit InTracing(bool dynamic = false) {
|
explicit InTracing(bool dynamic = false, bool grad = false) {
|
||||||
trace_stack.push_back(dynamic);
|
grad_counter += grad;
|
||||||
|
trace_stack.push_back({dynamic, grad});
|
||||||
}
|
}
|
||||||
~InTracing() {
|
~InTracing() {
|
||||||
|
grad_counter -= trace_stack.back().second;
|
||||||
trace_stack.pop_back();
|
trace_stack.pop_back();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -32,11 +34,16 @@ struct InTracing {
|
|||||||
}
|
}
|
||||||
static bool in_dynamic_tracing() {
|
static bool in_dynamic_tracing() {
|
||||||
// compile is always and only the outer-most transform
|
// compile is always and only the outer-most transform
|
||||||
return in_tracing() && trace_stack.front();
|
return in_tracing() && trace_stack.front().first;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool in_grad_tracing() {
|
||||||
|
return grad_counter > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static std::vector<char> trace_stack;
|
static int grad_counter;
|
||||||
|
static std::vector<std::pair<char, char>> trace_stack;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct RetainGraph {
|
struct RetainGraph {
|
||||||
@ -67,6 +74,11 @@ inline bool in_dynamic_tracing() {
|
|||||||
return detail::InTracing::in_dynamic_tracing();
|
return detail::InTracing::in_dynamic_tracing();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Return true if we are in a gradient trace (vjp, jvp, etc). */
|
||||||
|
inline bool in_grad_tracing() {
|
||||||
|
return detail::InTracing::in_grad_tracing();
|
||||||
|
}
|
||||||
|
|
||||||
inline bool retain_graph() {
|
inline bool retain_graph() {
|
||||||
return detail::RetainGraph::retain_graph();
|
return detail::RetainGraph::retain_graph();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user