[CUDA] cuDNN backward attention (#2762)

Cheng
2025-11-19 08:13:50 +09:00
committed by GitHub
parent b167f0df1c
commit 6f35017d1b
8 changed files with 472 additions and 82 deletions


@@ -168,7 +168,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
FetchContent_Declare(
cudnn
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
GIT_TAG v1.14.0
GIT_TAG v1.16.0
GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)


@@ -5,7 +5,6 @@
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"
#include <nvtx3/nvtx3.hpp>
@@ -26,6 +25,11 @@ namespace {
std::vector<int64_t> normalized_strides(const array& x) {
std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
if (std::all_of(
strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
strides.back() = 1;
return strides;
}
if (!x.flags().row_contiguous || x.ndim() < 2) {
return strides;
}
@@ -71,11 +75,41 @@ struct SDPACacheKey {
std::array<int64_t, QKV_NDIM> k_strides;
std::array<int64_t, QKV_NDIM> v_strides;
bool do_causal;
bool output_logsumexp;
};
inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
bool do_causal,
bool output_logsumexp = true) {
BytesKey<SDPACacheKey> cache_key;
cache_key.pod = {
encoder.device().cuda_device(),
dtype_to_cudnn_type(q.dtype()),
vector_key<QKV_NDIM>(q.shape()),
vector_key<QKV_NDIM>(k.shape()),
vector_key<QKV_NDIM>(v.shape()),
vector_key<QKV_NDIM>(q.strides()),
vector_key<QKV_NDIM>(k.strides()),
vector_key<QKV_NDIM>(v.strides()),
do_causal,
output_logsumexp,
};
return cache_key;
}
auto& sdpa_cache() {
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 128);
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 16);
return cache;
}
auto& sdpa_backward_cache() {
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 16);
return cache;
}
@@ -85,6 +119,12 @@ enum UIDS {
V,
SCALE,
O,
STATS,
// Backward graph:
D_Q,
D_K,
D_V,
D_O,
};
fe::graph::Graph build_sdpa_graph(
@@ -93,7 +133,9 @@ fe::graph::Graph build_sdpa_graph(
const array& k,
const array& v,
bool do_causal,
const array& o) {
bool output_logsumexp,
const array& o,
const array& stats) {
auto dtype = fe::DataType_t::HALF;
if (q.dtype() == bfloat16) {
dtype = fe::DataType_t::BFLOAT16;
@@ -119,15 +161,19 @@ fe::graph::Graph build_sdpa_graph(
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
auto sdpa_options = fe::graph::SDPA_attributes()
.set_name("sdpa_cudnn")
.set_attn_scale(scale)
.set_causal_mask(do_causal)
.set_generate_stats(false);
auto options = fe::graph::SDPA_attributes()
.set_name("sdpa_cudnn")
.set_attn_scale(scale)
.set_causal_mask(do_causal)
.set_generate_stats(output_logsumexp);
auto [o_, _] = graph.sdpa(q_, k_, v_, sdpa_options);
auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
o_->set_output(true);
set_tensor_attrs(o_, O, o);
if (output_logsumexp) {
stats_->set_output(true).set_data_type(fe::DataType_t::FLOAT);
set_tensor_attrs(stats_, STATS, stats);
}
CHECK_CUDNN_FE_ERROR(graph.validate());
CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
@@ -140,6 +186,100 @@ fe::graph::Graph build_sdpa_graph(
return graph;
}
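For background (this is the standard flash-attention bookkeeping, not something the diff spells out): the STATS tensor emitted above holds the per-row logsumexp of the attention logits, saved so the backward graph below can rebuild the softmax probabilities instead of materializing the full T x T matrix; the fast.cpp changes further down allocate it as a {B, H, T, 1} float32 array. Per batch and head, for query row i and key column j (causal masking omitted):

S_{ij} = \mathrm{scale}\,(q_i \cdot k_j), \qquad
\mathrm{stats}_i = \log \sum_j e^{S_{ij}}, \qquad
P_{ij} = e^{S_{ij} - \mathrm{stats}_i}, \qquad
o_i = \sum_j P_{ij}\, v_j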
fe::graph::Graph build_sdpa_backward_graph(
cudnnHandle_t handle,
const array& q,
const array& k,
const array& v,
bool do_causal,
const array& o,
const array& d_o,
const array& stats,
array& d_q,
array& d_k,
array& d_v) {
auto dtype = fe::DataType_t::HALF;
if (q.dtype() == bfloat16) {
dtype = fe::DataType_t::BFLOAT16;
}
fe::graph::Graph graph;
graph.set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q"));
auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K"));
auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V"));
auto o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("O"));
auto d_o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("D_O"));
auto stats_ = graph.tensor(fe::graph::Tensor_attributes().set_name("STATS"));
set_tensor_attrs(q_, Q, q);
set_tensor_attrs(k_, K, k);
set_tensor_attrs(v_, V, v);
set_tensor_attrs(o_, O, o);
set_tensor_attrs(d_o_, D_O, d_o);
set_tensor_attrs(stats_, STATS, stats);
stats_->set_data_type(fe::DataType_t::FLOAT);
auto scale = graph.tensor(fe::graph::Tensor_attributes()
.set_name("Scale")
.set_uid(SCALE)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
auto options = fe::graph::SDPA_backward_attributes()
.set_name("sdpa_backward_cudnn")
.set_attn_scale(scale)
.set_causal_mask(do_causal);
auto [d_q_, d_k_, d_v_] =
graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
d_q_->set_output(true);
d_k_->set_output(true);
d_v_->set_output(true);
set_tensor_attrs(d_q_, D_Q, d_q);
set_tensor_attrs(d_k_, D_K, d_k);
set_tensor_attrs(d_v_, D_V, d_v);
CHECK_CUDNN_FE_ERROR(graph.validate());
CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A}));
graph.select_behavior_notes(
{fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
CHECK_CUDNN_FE_ERROR(graph.check_support(handle));
CHECK_CUDNN_FE_ERROR(graph.build_plans(handle));
return graph;
}
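For orientation, graph.sdpa_backward above produces the standard attention gradients; written out as a background derivation, with P reconstructed from Q, K and STATS as in the forward pass:

dV = P^{\top}\, dO, \qquad dP = dO\, V^{\top}, \qquad
dS_{ij} = P_{ij}\,\bigl(dP_{ij} - dO_i \cdot o_i\bigr), \qquad
dQ = \mathrm{scale}\, dS\, K, \qquad dK = \mathrm{scale}\, dS^{\top} Q

The per-row term dO_i \cdot o_i is why O and D_O are fed to the backward graph alongside STATS.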
void execute_graph(
cu::CommandEncoder& encoder,
cudnnHandle_t handle,
fe::graph::Graph& graph,
std::unordered_map<int64_t, void*>& variant_pack) {
int64_t workspace_size = 0;
CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size));
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder),
{static_cast<int>(workspace_size)},
uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
}
cudnnSetStream(handle, encoder.stream());
CudaGraph cuda_graph(encoder.device());
CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph(
handle, variant_pack, workspace_ptr, cuda_graph));
encoder.add_graph_node(cuda_graph);
}
} // namespace
bool supports_sdpa_cudnn(
@@ -170,7 +310,7 @@ bool supports_sdpa_cudnn(
}
}
// Only use cuDNN for prefilling.
// Only use cuDNN for prefilling and training.
if (q.shape(2) != k.shape(2)) {
return false;
}
@@ -191,9 +331,13 @@ void sdpa_cudnn(
const array& v,
float scale,
array& o,
array& stats,
bool do_causal,
bool output_logsumexp,
Stream s) {
auto& encoder = cu::get_command_encoder(s);
auto handle = encoder.device().cudnn_handle();
// TODO: Handle donation.
// TODO: Make O use same memory layout with Q.
o.set_data(cu::malloc_async(o.nbytes(), encoder));
@@ -203,28 +347,19 @@ void sdpa_cudnn(
encoder.set_input_array(v);
encoder.set_output_array(o);
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
if (output_logsumexp) {
stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
encoder.set_output_array(stats);
}
// Search cache.
BytesKey<SDPACacheKey> cache_key;
cache_key.pod = {
encoder.device().cuda_device(),
dtype_to_cudnn_type(q.dtype()),
vector_key<QKV_NDIM>(q.shape()),
vector_key<QKV_NDIM>(k.shape()),
vector_key<QKV_NDIM>(v.shape()),
vector_key<QKV_NDIM>(q.strides()),
vector_key<QKV_NDIM>(k.strides()),
vector_key<QKV_NDIM>(v.strides()),
do_causal,
};
auto cache_key =
build_sdpa_cache_key(encoder, q, k, v, do_causal, output_logsumexp);
auto it = sdpa_cache().find(cache_key);
if (it == sdpa_cache().end()) {
it =
sdpa_cache()
.emplace(cache_key, build_sdpa_graph(handle, q, k, v, do_causal, o))
.first;
auto graph = build_sdpa_graph(
handle, q, k, v, do_causal, output_logsumexp, o, stats);
it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;
@@ -234,23 +369,67 @@ void sdpa_cudnn(
{V, const_cast<void*>(gpu_ptr<void>(v))},
{SCALE, &scale},
{O, gpu_ptr<void>(o)}};
int64_t workspace_size = 0;
CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size));
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder),
{static_cast<int>(workspace_size)},
uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
if (output_logsumexp) {
variant_pack[STATS] = gpu_ptr<void>(stats);
}
CudaGraph cuda_graph(encoder.device());
CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph(
handle, variant_pack, workspace_ptr, cuda_graph));
encoder.add_graph_node(cuda_graph);
execute_graph(encoder, handle, graph, variant_pack);
}
void sdpa_backward_cudnn(
const array& q,
const array& k,
const array& v,
float scale,
const array& o,
const array& stats,
bool do_causal,
const array& d_o,
array& d_q,
array& d_k,
array& d_v,
Stream s) {
auto& encoder = cu::get_command_encoder(s);
auto handle = encoder.device().cudnn_handle();
// TODO: Handle donation.
d_q.set_data(cu::malloc_async(d_q.nbytes(), encoder));
d_k.set_data(cu::malloc_async(d_k.nbytes(), encoder));
d_v.set_data(cu::malloc_async(d_v.nbytes(), encoder));
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_input_array(o);
encoder.set_input_array(stats);
encoder.set_input_array(d_o);
encoder.set_output_array(d_q);
encoder.set_output_array(d_k);
encoder.set_output_array(d_v);
// Search cache.
auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal);
auto it = sdpa_backward_cache().find(cache_key);
if (it == sdpa_backward_cache().end()) {
auto graph = build_sdpa_backward_graph(
handle, q, k, v, do_causal, o, d_o, stats, d_q, d_k, d_v);
it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;
std::unordered_map<int64_t, void*> variant_pack{
{Q, const_cast<void*>(gpu_ptr<void>(q))},
{K, const_cast<void*>(gpu_ptr<void>(k))},
{V, const_cast<void*>(gpu_ptr<void>(v))},
{SCALE, &scale},
{O, const_cast<void*>(gpu_ptr<void>(o))},
{STATS, const_cast<void*>(gpu_ptr<void>(stats))},
{D_O, const_cast<void*>(gpu_ptr<void>(d_o))},
{D_Q, gpu_ptr<void>(d_q)},
{D_K, gpu_ptr<void>(d_k)},
{D_V, gpu_ptr<void>(d_v)}};
execute_graph(encoder, handle, graph, variant_pack);
}
// Defined in scaled_dot_product_attention.cu file.
@@ -260,7 +439,8 @@ bool supports_sdpa_vector(
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal);
bool do_causal,
bool output_logsumexp);
void sdpa_vector(
const array& q,
const array& k,
@@ -280,21 +460,21 @@ bool ScaledDotProductAttention::use_fallback(
bool has_mask,
bool has_arr_mask,
bool do_causal,
bool is_training,
bool output_logsumexp,
Stream s) {
if (detail::in_grad_tracing()) {
return true;
}
if (s.device == Device::cpu) {
return true;
}
return !supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal) &&
return !supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
!supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s);
}
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
std::vector<array>& outputs) {
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
auto& s = stream();
@@ -302,20 +482,56 @@ void ScaledDotProductAttention::eval_gpu(
array q = prepare_sdpa_input(inputs[0], s);
array k = prepare_sdpa_input(inputs[1], s);
array v = prepare_sdpa_input(inputs[2], s);
auto& out = outputs[0];
auto& stats = outputs[1];
bool has_mask = inputs.size() - has_sinks_ > 3;
bool has_arr_mask = has_mask && !do_causal_;
if (supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal_)) {
if (supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
if (has_sinks_) {
sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s);
} else {
sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
}
} else {
sdpa_cudnn(q, k, v, scale_, out, do_causal_, s);
sdpa_cudnn(q, k, v, scale_, out, stats, do_causal_, output_logsumexp_, s);
}
}
bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
// The frontend adds a padding mask when sequence length is not a multiple of
// tile size.
if (q.shape(2) % 128 != 0) {
return true;
}
return s.device == Device::cpu;
}
void ScaledDotProductAttentionVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("ScaledDotProductAttentionVJP::eval_gpu");
auto& s = stream();
assert(inputs.size() == 6);
array q = prepare_sdpa_input(inputs[0], s);
array k = prepare_sdpa_input(inputs[1], s);
array v = prepare_sdpa_input(inputs[2], s);
array o = prepare_sdpa_input(inputs[3], s);
array stats = prepare_sdpa_input(inputs[4], s);
array d_o = prepare_sdpa_input(inputs[5], s);
assert(outputs.size() == 3);
auto& d_q = outputs[0];
auto& d_k = outputs[1];
auto& d_v = outputs[2];
sdpa_backward_cudnn(
q, k, v, scale_, o, stats, do_causal_, d_o, d_q, d_k, d_v, s);
}
} // namespace fast
} // namespace mlx::core


@@ -664,7 +664,12 @@ bool supports_sdpa_vector(
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal) {
bool do_causal,
bool output_logsumexp) {
if (output_logsumexp) {
return false;
}
const int value_head_dim = v.shape(-1);
const int query_head_dim = q.shape(-1);
const int query_sequence_length = q.shape(2);


@@ -7,7 +7,6 @@
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
namespace mlx::core::fast {
@@ -379,8 +378,15 @@ bool ScaledDotProductAttention::use_fallback(
bool has_mask,
bool has_arr_mask,
bool do_causal,
bool is_training,
bool output_logsumexp,
Stream s) {
if (detail::in_grad_tracing()) {
if (is_training) {
// It's faster for training on Metal to use the unfused SDPA for both
// forward and backward.
return true;
}
if (output_logsumexp) {
return true;
}
if (s.device == Device::cpu) {
@@ -414,14 +420,14 @@ bool ScaledDotProductAttention::use_fallback(
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
std::vector<array>& outputs) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& q_pre = inputs[0];
auto& k_pre = inputs[1];
auto& v_pre = inputs[2];
auto& o = out;
auto& o = outputs[0];
std::vector<array> copies;
@@ -553,4 +559,14 @@ void ScaledDotProductAttention::eval_gpu(
d.add_temporaries(std::move(copies), s.index);
}
bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
return true;
}
void ScaledDotProductAttentionVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("NYI");
}
} // namespace mlx::core::fast


@@ -30,6 +30,14 @@ bool fast::ScaledDotProductAttention::use_fallback(
bool has_mask,
bool has_arr_mask,
bool do_causal,
bool is_training,
bool output_logsumexp,
Stream s) {
return true;
}
bool fast::ScaledDotProductAttentionVJP::use_fallback(
const array& q,
Stream s) {
return true;
}
@@ -153,7 +161,8 @@ NO_GPU_MULTI(LayerNormVJP)
NO_GPU_USE_FALLBACK(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_USE_FALLBACK(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(ScaledDotProductAttention)
NO_GPU_MULTI(ScaledDotProductAttentionVJP)
NO_GPU_MULTI(ConvertFP8)
NO_GPU_MULTI(Quantize)
NO_GPU_MULTI(CustomKernel)


@@ -6,6 +6,7 @@
#include "mlx/fast_primitives.h"
#include "mlx/ops.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
namespace mlx::core::fast {
@@ -784,22 +785,88 @@ array scaled_dot_product_attention(
inputs.push_back(astype(*sinks, final_type, stream));
}
bool is_training = detail::in_grad_tracing();
bool has_fast_vjp = !ScaledDotProductAttentionVJP::use_fallback(q, stream);
bool output_logsumexp = is_training && has_fast_vjp;
if (!ScaledDotProductAttention::use_fallback(
q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
return array(
std::move(out_shape),
final_type,
std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, do_causal, has_sinks),
std::move(inputs));
q,
k,
v,
has_mask,
has_arr_mask,
do_causal,
is_training,
output_logsumexp,
stream)) {
Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
auto primitive = std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, do_causal, has_sinks, output_logsumexp);
if (output_logsumexp) {
return array::make_arrays(
{std::move(out_shape), Shape{q.shape(0), q.shape(1), q.shape(2), 1}},
{final_type, float32},
primitive,
std::move(inputs))[0];
} else {
return array(
std::move(out_shape), final_type, primitive, std::move(inputs));
}
}
return fallback(std::move(inputs))[0];
}
std::vector<array> ScaledDotProductAttention::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
assert(primals.size() >= 3);
assert(cotangents.size() == outputs.size());
auto s = stream();
if (ScaledDotProductAttentionVJP::use_fallback(primals[0], s)) {
assert(outputs.size() == 1);
return Custom::vjp(primals, cotangents, argnums, outputs);
}
auto fallback = [sdpa = fallback_, s](const std::vector<array>& inputs) {
std::vector<array> primals(inputs.begin(), std::prev(inputs.end()));
auto [_, vjps] = mlx::core::vjp(sdpa, primals, {inputs.back()});
return vjps;
};
std::vector<Shape> shapes;
std::vector<Dtype> dtypes;
for (int i = 0; i < primals.size(); ++i) {
shapes.push_back(primals[i].shape());
dtypes.push_back(primals[i].dtype());
}
auto primitive = std::make_shared<ScaledDotProductAttentionVJP>(
s, fallback, scale_, do_causal_, has_sinks_);
std::vector<array> inputs = primals;
inputs.push_back(outputs[0]);
inputs.push_back(outputs[1]);
inputs.push_back(cotangents[0]);
auto vjps = array::make_arrays(std::move(shapes), dtypes, primitive, inputs);
std::vector<array> returned_vjps;
for (int arg : argnums) {
returned_vjps.push_back(std::move(vjps[arg]));
}
return returned_vjps;
}
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
const ScaledDotProductAttention& a_other =
static_cast<const ScaledDotProductAttention&>(other);
return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ &&
has_sinks_ == a_other.has_sinks_ &&
output_logsumexp_ == a_other.output_logsumexp_;
}
bool ScaledDotProductAttentionVJP::is_equivalent(const Primitive& other) const {
const ScaledDotProductAttentionVJP& a_other =
static_cast<const ScaledDotProductAttentionVJP&>(other);
return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ &&
has_sinks_ == a_other.has_sinks_;
}
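From the Python side nothing changes for callers: scaled_dot_product_attention still returns a single array, and gradients now route through ScaledDotProductAttentionVJP (the fused cuDNN path on CUDA, the unfused fallback elsewhere). A minimal usage sketch, mirroring the new test at the bottom of this change (shapes are illustrative):

import mlx.core as mx

B, H, T, D = 2, 8, 128, 64
q = mx.random.normal(shape=(B, H, T, D), dtype=mx.float16)
k = mx.random.normal(shape=(B, H, T, D), dtype=mx.float16)
v = mx.random.normal(shape=(B, H, T, D), dtype=mx.float16)

def loss(q, k, v):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5).sum()

# Differentiating w.r.t. q, k and v exercises the new backward primitive.
dq, dk, dv = mx.grad(loss, argnums=(0, 1, 2))(q, k, v)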


@@ -15,7 +15,7 @@ class Custom : public Primitive {
explicit Custom(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback)
: Primitive(stream), fallback_(fallback) {}
: Primitive(stream), fallback_(std::move(fallback)) {}
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
@@ -32,7 +32,7 @@ class Custom : public Primitive {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
private:
protected:
std::function<std::vector<array>(std::vector<array>)> fallback_;
};
@@ -42,7 +42,7 @@ class RMSNorm : public Custom {
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float eps)
: Custom(stream, fallback), eps_(eps) {}
: Custom(stream, std::move(fallback)), eps_(eps) {}
static bool use_fallback(Stream stream);
@@ -77,7 +77,7 @@ class RMSNormVJP : public Custom {
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float eps)
: Custom(stream, fallback), eps_(eps) {}
: Custom(stream, std::move(fallback)), eps_(eps) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
@@ -102,7 +102,7 @@ class LayerNorm : public Custom {
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float eps)
: Custom(stream, fallback), eps_(eps) {}
: Custom(stream, std::move(fallback)), eps_(eps) {}
static bool use_fallback(Stream s);
@@ -136,7 +136,7 @@ class LayerNormVJP : public Custom {
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float eps)
: Custom(stream, fallback), eps_(eps) {}
: Custom(stream, std::move(fallback)), eps_(eps) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
@@ -165,7 +165,7 @@ class RoPE : public Custom {
float base,
float scale,
bool forward)
: Custom(stream, fallback),
: Custom(stream, std::move(fallback)),
dims_(dims),
traditional_(traditional),
base_(base),
@@ -205,16 +205,18 @@ class RoPE : public Custom {
class ScaledDotProductAttention : public Custom {
public:
explicit ScaledDotProductAttention(
ScaledDotProductAttention(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float scale,
bool do_causal,
bool has_sinks)
: Custom(stream, fallback),
bool has_sinks,
bool output_logsumexp)
: Custom(stream, std::move(fallback)),
scale_(scale),
do_causal_(do_causal),
has_sinks_(has_sinks) {}
has_sinks_(has_sinks),
output_logsumexp_(output_logsumexp) {}
static bool use_fallback(
const array& q,
@@ -223,6 +225,8 @@ class ScaledDotProductAttention : public Custom {
bool has_mask,
bool has_arr_mask,
bool do_causal,
bool is_training,
bool output_logsumexp,
Stream s);
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -231,15 +235,55 @@ class ScaledDotProductAttention : public Custom {
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
eval_gpu(inputs, outputs[0]);
}
override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
void eval_gpu(const std::vector<array>& inputs, array& out);
bool is_equivalent(const Primitive& other) const override;
DEFINE_NAME(ScaledDotProductAttention);
DEFINE_INPUT_OUTPUT_SHAPE()
auto state() const {
return std::make_tuple(
nullptr, scale_, do_causal_, has_sinks_, output_logsumexp_);
}
private:
float scale_;
bool do_causal_;
bool has_sinks_;
bool output_logsumexp_;
};
class ScaledDotProductAttentionVJP : public Custom {
public:
ScaledDotProductAttentionVJP(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float scale,
bool do_causal,
bool has_sinks)
: Custom(stream, std::move(fallback)),
scale_(scale),
do_causal_(do_causal),
has_sinks_(has_sinks) {}
static bool use_fallback(const array& q, Stream s);
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("NYI");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_NAME(ScaledDotProductAttentionVJP);
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(nullptr, scale_, do_causal_, has_sinks_);
}
@@ -288,7 +332,7 @@ class Quantize : public Custom {
int bits,
QuantizationMode mode,
bool dequantize)
: Custom(stream, fallback),
: Custom(stream, std::move(fallback)),
group_size_(group_size),
bits_(bits),
mode_(mode),


@@ -738,6 +738,39 @@ class TestSDPA(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(out, expected, atol=1e-5))
def test_sdpa_grad(self):
B, N_kv, T, D = (2, 8, 128, 64)
scale = D**-0.5
f1 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
f2 = lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
f3 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale).sum()
f4 = lambda q, k, v: (
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
).sum()
# High tolerance due to cuDNN SDPA kernel requiring tf32.
tolerance = {"rtol": 1e-2, "atol": 1e-2}
for N_q in (8, 32):
q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)
k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
cotan = mx.ones_like(q)
o1, vjp1 = mx.vjp(f1, [q, k, v], [cotan])
o2, vjp2 = mx.vjp(f2, [q, k, v], [cotan])
self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
for i in range(3):
self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
g1 = mx.grad(f3)(q, k, v)
g2 = mx.grad(f4)(q, k, v)
self.assertTrue(mx.allclose(g1, g2, **tolerance))
if __name__ == "__main__":
mlx_tests.MLXTestRunner(failfast=True)