Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
[CUDA] cuDNN backward attention (#2762)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
@@ -168,7 +168,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
 FetchContent_Declare(
   cudnn
   GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
-  GIT_TAG v1.14.0
+  GIT_TAG v1.16.0
   GIT_SHALLOW TRUE
   EXCLUDE_FROM_ALL)
 set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
@@ -5,7 +5,6 @@
 #include "mlx/backend/cuda/lru_cache.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/fast_primitives.h"
-#include "mlx/transforms_impl.h"
 
 #include <nvtx3/nvtx3.hpp>
 
@@ -26,6 +25,11 @@ namespace {
 
 std::vector<int64_t> normalized_strides(const array& x) {
   std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
+  if (std::all_of(
+          strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
+    strides.back() = 1;
+    return strides;
+  }
   if (!x.flags().row_contiguous || x.ndim() < 2) {
     return strides;
   }
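For context: the new early return above handles fully broadcast inputs, which report a stride of 0 in every dimension; cuDNN rejects such tensor descriptors, so the innermost stride is forced to 1. A minimal standalone sketch of just that case (illustrative toy; the real normalized_strides also rewrites row-contiguous strides, not shown here):

#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> fix_all_zero_strides(std::vector<int64_t> strides) {
  bool all_zero = std::all_of(
      strides.begin(), strides.end(), [](int64_t s) { return s == 0; });
  if (all_zero && !strides.empty()) {
    strides.back() = 1; // e.g. {0, 0, 0, 0} -> {0, 0, 0, 1}
  }
  return strides;
}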
@@ -71,11 +75,41 @@ struct SDPACacheKey {
   std::array<int64_t, QKV_NDIM> k_strides;
   std::array<int64_t, QKV_NDIM> v_strides;
   bool do_causal;
+  bool output_logsumexp;
 };
 
+inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
+    cu::CommandEncoder& encoder,
+    const array& q,
+    const array& k,
+    const array& v,
+    bool do_causal,
+    bool output_logsumexp = true) {
+  BytesKey<SDPACacheKey> cache_key;
+  cache_key.pod = {
+      encoder.device().cuda_device(),
+      dtype_to_cudnn_type(q.dtype()),
+      vector_key<QKV_NDIM>(q.shape()),
+      vector_key<QKV_NDIM>(k.shape()),
+      vector_key<QKV_NDIM>(v.shape()),
+      vector_key<QKV_NDIM>(q.strides()),
+      vector_key<QKV_NDIM>(k.strides()),
+      vector_key<QKV_NDIM>(v.strides()),
+      do_causal,
+      output_logsumexp,
+  };
+  return cache_key;
+}
+
 auto& sdpa_cache() {
   static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
-      "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 128);
+      "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 16);
   return cache;
 }
 
+auto& sdpa_backward_cache() {
+  static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
+      "MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 16);
+  return cache;
+}
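The caches above memoize built cuDNN graphs keyed on everything that affects plan validity: device, dtype, Q/K/V shapes and strides, the causal flag, and (new here) whether logsumexp stats are emitted. A minimal sketch of the lookup-or-build pattern they serve, with a plain std::unordered_map standing in for LRUBytesKeyCache (a simplification; the real cache is LRU-bounded, with capacity read from the named environment variable):

#include <string>
#include <unordered_map>

struct BuiltGraph {
  std::string plan; // stand-in for a compiled cuDNN execution graph
};

BuiltGraph build_graph(const std::string& key) {
  return BuiltGraph{"plan for " + key}; // expensive step, done once per key
}

BuiltGraph& find_or_build(
    std::unordered_map<std::string, BuiltGraph>& cache,
    const std::string& key) {
  auto it = cache.find(key);
  if (it == cache.end()) {
    it = cache.emplace(key, build_graph(key)).first;
  }
  return it->second;
}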
@@ -85,6 +119,12 @@ enum UIDS {
   V,
   SCALE,
   O,
+  STATS,
+  // Backward graph:
+  D_Q,
+  D_K,
+  D_V,
+  D_O,
 };
 
 fe::graph::Graph build_sdpa_graph(
@@ -93,7 +133,9 @@ fe::graph::Graph build_sdpa_graph(
     const array& k,
     const array& v,
     bool do_causal,
-    const array& o) {
+    bool output_logsumexp,
+    const array& o,
+    const array& stats) {
   auto dtype = fe::DataType_t::HALF;
   if (q.dtype() == bfloat16) {
     dtype = fe::DataType_t::BFLOAT16;
@@ -119,15 +161,19 @@ fe::graph::Graph build_sdpa_graph(
           .set_is_pass_by_value(true)
           .set_data_type(fe::DataType_t::FLOAT));
 
-  auto sdpa_options = fe::graph::SDPA_attributes()
-                          .set_name("sdpa_cudnn")
-                          .set_attn_scale(scale)
-                          .set_causal_mask(do_causal)
-                          .set_generate_stats(false);
+  auto options = fe::graph::SDPA_attributes()
+                     .set_name("sdpa_cudnn")
+                     .set_attn_scale(scale)
+                     .set_causal_mask(do_causal)
+                     .set_generate_stats(output_logsumexp);
 
-  auto [o_, _] = graph.sdpa(q_, k_, v_, sdpa_options);
+  auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
   o_->set_output(true);
   set_tensor_attrs(o_, O, o);
+  if (output_logsumexp) {
+    stats_->set_output(true).set_data_type(fe::DataType_t::FLOAT);
+    set_tensor_attrs(stats_, STATS, stats);
+  }
 
   CHECK_CUDNN_FE_ERROR(graph.validate());
   CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
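For reference, the STATS tensor emitted when output_logsumexp is set holds the standard flash-attention softmax statistics: the per-row logsumexp of the scaled logits, from which the backward pass can recompute the softmax without storing the full attention matrix. In LaTeX (the usual convention; not spelled out in the diff itself):

% Scaled logits, saved stats, and the softmax row recomputed from them:
\[
s_{ij} = \tau \,\langle q_i, k_j \rangle, \qquad
L_i = \log \sum_{j} e^{s_{ij}}, \qquad
p_{ij} = e^{\, s_{ij} - L_i},
\]
% with \tau the attention scale; under causal masking the sum runs over
% the unmasked positions j <= i.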
@@ -140,6 +186,100 @@ fe::graph::Graph build_sdpa_graph(
   return graph;
 }
 
+fe::graph::Graph build_sdpa_backward_graph(
+    cudnnHandle_t handle,
+    const array& q,
+    const array& k,
+    const array& v,
+    bool do_causal,
+    const array& o,
+    const array& d_o,
+    const array& stats,
+    array& d_q,
+    array& d_k,
+    array& d_v) {
+  auto dtype = fe::DataType_t::HALF;
+  if (q.dtype() == bfloat16) {
+    dtype = fe::DataType_t::BFLOAT16;
+  }
+
+  fe::graph::Graph graph;
+  graph.set_io_data_type(dtype)
+      .set_intermediate_data_type(fe::DataType_t::FLOAT)
+      .set_compute_data_type(fe::DataType_t::FLOAT);
+
+  auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q"));
+  auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K"));
+  auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V"));
+  auto o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("O"));
+  auto d_o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("D_O"));
+  auto stats_ = graph.tensor(fe::graph::Tensor_attributes().set_name("STATS"));
+  set_tensor_attrs(q_, Q, q);
+  set_tensor_attrs(k_, K, k);
+  set_tensor_attrs(v_, V, v);
+  set_tensor_attrs(o_, O, o);
+  set_tensor_attrs(d_o_, D_O, d_o);
+  set_tensor_attrs(stats_, STATS, stats);
+  stats_->set_data_type(fe::DataType_t::FLOAT);
+
+  auto scale = graph.tensor(fe::graph::Tensor_attributes()
+                                .set_name("Scale")
+                                .set_uid(SCALE)
+                                .set_dim({1, 1, 1, 1})
+                                .set_stride({1, 1, 1, 1})
+                                .set_is_pass_by_value(true)
+                                .set_data_type(fe::DataType_t::FLOAT));
+
+  auto options = fe::graph::SDPA_backward_attributes()
+                     .set_name("sdpa_backward_cudnn")
+                     .set_attn_scale(scale)
+                     .set_causal_mask(do_causal);
+
+  auto [d_q_, d_k_, d_v_] =
+      graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
+  d_q_->set_output(true);
+  d_k_->set_output(true);
+  d_v_->set_output(true);
+  set_tensor_attrs(d_q_, D_Q, d_q);
+  set_tensor_attrs(d_k_, D_K, d_k);
+  set_tensor_attrs(d_v_, D_V, d_v);
+
+  CHECK_CUDNN_FE_ERROR(graph.validate());
+  CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
+  CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A}));
+  graph.select_behavior_notes(
+      {fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
+  CHECK_CUDNN_FE_ERROR(graph.check_support(handle));
+  CHECK_CUDNN_FE_ERROR(graph.build_plans(handle));
+
+  return graph;
+}
+
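The fused backward node above computes the usual SDPA gradients. For reference, with P the softmax recomputed from the saved stats and O = PV, the identities it implements are (standard derivation, included as a reading aid):

% With P the softmax recomputed from stats and O = P V:
\[
dV = P^{\top} dO, \qquad
dP = dO \, V^{\top}, \qquad
dS_{ij} = p_{ij}\,\bigl(dP_{ij} - \delta_i\bigr), \qquad
\delta_i = \sum_{j} dP_{ij}\, p_{ij} = \sum_{d} dO_{id}\, O_{id},
\]
\[
dQ = \tau \, dS \, K, \qquad dK = \tau \, dS^{\top} Q .
\]
% The delta_i identity is why passing O and d_o into the graph is enough to
% correct for the softmax Jacobian row by row.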
+void execute_graph(
+    cu::CommandEncoder& encoder,
+    cudnnHandle_t handle,
+    fe::graph::Graph& graph,
+    std::unordered_map<int64_t, void*>& variant_pack) {
+  int64_t workspace_size = 0;
+  CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size));
+  void* workspace_ptr = nullptr;
+  if (workspace_size > 0) {
+    array workspace(
+        cu::malloc_async(workspace_size, encoder),
+        {static_cast<int>(workspace_size)},
+        uint8);
+    encoder.add_temporary(workspace);
+    workspace_ptr = gpu_ptr<void>(workspace);
+  }
+
+  cudnnSetStream(handle, encoder.stream());
+
+  CudaGraph cuda_graph(encoder.device());
+  CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph(
+      handle, variant_pack, workspace_ptr, cuda_graph));
+  encoder.add_graph_node(cuda_graph);
+}
+
 } // namespace
 
 bool supports_sdpa_cudnn(
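execute_graph records the cuDNN work into a CUDA graph node via populate_cuda_graph so the command encoder can replay it. For contrast, a sketch of running the same built graph eagerly with cudnn-frontend's uid-keyed execute overload (assumes a graph and variant pack built as in this file; error handling reduced to early returns):

#include <cstdint>
#include <unordered_map>

#include <cuda_runtime.h>
#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

bool run_eagerly(
    cudnnHandle_t handle,
    fe::graph::Graph& graph,
    std::unordered_map<int64_t, void*>& variant_pack) {
  int64_t workspace_size = 0;
  if (graph.get_workspace_size(workspace_size).is_bad()) {
    return false;
  }
  void* workspace = nullptr;
  if (workspace_size > 0 &&
      cudaMalloc(&workspace, workspace_size) != cudaSuccess) {
    return false;
  }
  // Launches the built plan on the handle's current stream.
  bool ok = graph.execute(handle, variant_pack, workspace).is_good();
  if (workspace != nullptr) {
    cudaFree(workspace);
  }
  return ok;
}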
@@ -170,7 +310,7 @@ bool supports_sdpa_cudnn(
     }
   }
 
-  // Only use cuDNN for prefilling.
+  // Only use cuDNN for prefilling and training.
   if (q.shape(2) != k.shape(2)) {
     return false;
   }
@@ -191,9 +331,13 @@ void sdpa_cudnn(
     const array& v,
     float scale,
     array& o,
+    array& stats,
     bool do_causal,
+    bool output_logsumexp,
     Stream s) {
   auto& encoder = cu::get_command_encoder(s);
+  auto handle = encoder.device().cudnn_handle();
 
   // TODO: Handle donation.
+  // TODO: Make O use same memory layout with Q.
   o.set_data(cu::malloc_async(o.nbytes(), encoder));
@@ -203,28 +347,19 @@ void sdpa_cudnn(
   encoder.set_input_array(v);
   encoder.set_output_array(o);
 
-  auto handle = encoder.device().cudnn_handle();
-  cudnnSetStream(handle, encoder.stream());
+  if (output_logsumexp) {
+    stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
+    encoder.set_output_array(stats);
+  }
 
   // Search cache.
-  BytesKey<SDPACacheKey> cache_key;
-  cache_key.pod = {
-      encoder.device().cuda_device(),
-      dtype_to_cudnn_type(q.dtype()),
-      vector_key<QKV_NDIM>(q.shape()),
-      vector_key<QKV_NDIM>(k.shape()),
-      vector_key<QKV_NDIM>(v.shape()),
-      vector_key<QKV_NDIM>(q.strides()),
-      vector_key<QKV_NDIM>(k.strides()),
-      vector_key<QKV_NDIM>(v.strides()),
-      do_causal,
-  };
+  auto cache_key =
+      build_sdpa_cache_key(encoder, q, k, v, do_causal, output_logsumexp);
   auto it = sdpa_cache().find(cache_key);
   if (it == sdpa_cache().end()) {
-    it =
-        sdpa_cache()
-            .emplace(cache_key, build_sdpa_graph(handle, q, k, v, do_causal, o))
-            .first;
+    auto graph = build_sdpa_graph(
+        handle, q, k, v, do_causal, output_logsumexp, o, stats);
+    it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
   }
   auto& graph = it->second;
@@ -234,23 +369,67 @@ void sdpa_cudnn(
       {V, const_cast<void*>(gpu_ptr<void>(v))},
       {SCALE, &scale},
       {O, gpu_ptr<void>(o)}};
 
-  int64_t workspace_size = 0;
-  CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size));
-  void* workspace_ptr = nullptr;
-  if (workspace_size > 0) {
-    array workspace(
-        cu::malloc_async(workspace_size, encoder),
-        {static_cast<int>(workspace_size)},
-        uint8);
-    encoder.add_temporary(workspace);
-    workspace_ptr = gpu_ptr<void>(workspace);
+  if (output_logsumexp) {
+    variant_pack[STATS] = gpu_ptr<void>(stats);
   }
 
-  CudaGraph cuda_graph(encoder.device());
-  CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph(
-      handle, variant_pack, workspace_ptr, cuda_graph));
-  encoder.add_graph_node(cuda_graph);
+  execute_graph(encoder, handle, graph, variant_pack);
 }
 
+void sdpa_backward_cudnn(
+    const array& q,
+    const array& k,
+    const array& v,
+    float scale,
+    const array& o,
+    const array& stats,
+    bool do_causal,
+    const array& d_o,
+    array& d_q,
+    array& d_k,
+    array& d_v,
+    Stream s) {
+  auto& encoder = cu::get_command_encoder(s);
+  auto handle = encoder.device().cudnn_handle();
+
+  // TODO: Handle donation.
+  d_q.set_data(cu::malloc_async(d_q.nbytes(), encoder));
+  d_k.set_data(cu::malloc_async(d_k.nbytes(), encoder));
+  d_v.set_data(cu::malloc_async(d_v.nbytes(), encoder));
+
+  encoder.set_input_array(q);
+  encoder.set_input_array(k);
+  encoder.set_input_array(v);
+  encoder.set_input_array(o);
+  encoder.set_input_array(stats);
+  encoder.set_input_array(d_o);
+  encoder.set_output_array(d_q);
+  encoder.set_output_array(d_k);
+  encoder.set_output_array(d_v);
+
+  // Search cache.
+  auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal);
+  auto it = sdpa_backward_cache().find(cache_key);
+  if (it == sdpa_backward_cache().end()) {
+    auto graph = build_sdpa_backward_graph(
+        handle, q, k, v, do_causal, o, d_o, stats, d_q, d_k, d_v);
+    it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
+  }
+  auto& graph = it->second;
+
+  std::unordered_map<int64_t, void*> variant_pack{
+      {Q, const_cast<void*>(gpu_ptr<void>(q))},
+      {K, const_cast<void*>(gpu_ptr<void>(k))},
+      {V, const_cast<void*>(gpu_ptr<void>(v))},
+      {SCALE, &scale},
+      {O, const_cast<void*>(gpu_ptr<void>(o))},
+      {STATS, const_cast<void*>(gpu_ptr<void>(stats))},
+      {D_O, const_cast<void*>(gpu_ptr<void>(d_o))},
+      {D_Q, gpu_ptr<void>(d_q)},
+      {D_K, gpu_ptr<void>(d_k)},
+      {D_V, gpu_ptr<void>(d_v)}};
+
+  execute_graph(encoder, handle, graph, variant_pack);
+}
 
 // Defined in scaled_dot_product_attention.cu file.
@@ -260,7 +439,8 @@ bool supports_sdpa_vector(
     const array& v,
     bool has_mask,
     bool has_arr_mask,
-    bool do_causal);
+    bool do_causal,
+    bool output_logsumexp);
 void sdpa_vector(
     const array& q,
     const array& k,
@@ -280,21 +460,21 @@ bool ScaledDotProductAttention::use_fallback(
     bool has_mask,
     bool has_arr_mask,
     bool do_causal,
+    bool is_training,
+    bool output_logsumexp,
     Stream s) {
-  if (detail::in_grad_tracing()) {
-    return true;
-  }
   if (s.device == Device::cpu) {
     return true;
   }
 
-  return !supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal) &&
+  return !supports_sdpa_vector(
+             q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
       !supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s);
 }
 
 void ScaledDotProductAttention::eval_gpu(
     const std::vector<array>& inputs,
-    array& out) {
+    std::vector<array>& outputs) {
   nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
 
   auto& s = stream();
@@ -302,20 +482,56 @@ void ScaledDotProductAttention::eval_gpu(
   array q = prepare_sdpa_input(inputs[0], s);
   array k = prepare_sdpa_input(inputs[1], s);
   array v = prepare_sdpa_input(inputs[2], s);
+  auto& out = outputs[0];
+  auto& stats = outputs[1];
   bool has_mask = inputs.size() - has_sinks_ > 3;
   bool has_arr_mask = has_mask && !do_causal_;
 
-  if (supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal_)) {
+  if (supports_sdpa_vector(
+          q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
     if (has_sinks_) {
       sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s);
     } else {
      sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
    }
  } else {
-    sdpa_cudnn(q, k, v, scale_, out, do_causal_, s);
+    sdpa_cudnn(q, k, v, scale_, out, stats, do_causal_, output_logsumexp_, s);
  }
}

+bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
+  // The frontend adds a padding mask when sequence length is not a multiple of
+  // tile size.
+  if (q.shape(2) % 128 != 0) {
+    return true;
+  }
+  return s.device == Device::cpu;
+}
+
+void ScaledDotProductAttentionVJP::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  nvtx3::scoped_range r("ScaledDotProductAttentionVJP::eval_gpu");
+
+  auto& s = stream();
+
+  assert(inputs.size() == 6);
+  array q = prepare_sdpa_input(inputs[0], s);
+  array k = prepare_sdpa_input(inputs[1], s);
+  array v = prepare_sdpa_input(inputs[2], s);
+  array o = prepare_sdpa_input(inputs[3], s);
+  array stats = prepare_sdpa_input(inputs[4], s);
+  array d_o = prepare_sdpa_input(inputs[5], s);
+
+  assert(outputs.size() == 3);
+  auto& d_q = outputs[0];
+  auto& d_k = outputs[1];
+  auto& d_v = outputs[2];
+
+  sdpa_backward_cudnn(
+      q, k, v, scale_, o, stats, do_causal_, d_o, d_q, d_k, d_v, s);
+}
+
 } // namespace fast
 
 } // namespace mlx::core
@@ -664,7 +664,12 @@ bool supports_sdpa_vector(
     const array& v,
     bool has_mask,
     bool has_arr_mask,
-    bool do_causal) {
+    bool do_causal,
+    bool output_logsumexp) {
+  if (output_logsumexp) {
+    return false;
+  }
+
   const int value_head_dim = v.shape(-1);
   const int query_head_dim = q.shape(-1);
   const int query_sequence_length = q.shape(2);
@@ -7,7 +7,6 @@
 #include "mlx/backend/metal/kernels/steel/attn/params.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
-#include "mlx/transforms_impl.h"
 #include "mlx/utils.h"
 
 namespace mlx::core::fast {
@@ -379,8 +378,15 @@ bool ScaledDotProductAttention::use_fallback(
     bool has_mask,
     bool has_arr_mask,
     bool do_causal,
+    bool is_training,
+    bool output_logsumexp,
     Stream s) {
-  if (detail::in_grad_tracing()) {
+  if (is_training) {
+    // It's faster for training on Metal to use the unfused SDPA for both
+    // forward and backward.
+    return true;
+  }
+  if (output_logsumexp) {
     return true;
   }
   if (s.device == Device::cpu) {
|
||||
|
||||
void ScaledDotProductAttention::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& q_pre = inputs[0];
|
||||
auto& k_pre = inputs[1];
|
||||
auto& v_pre = inputs[2];
|
||||
auto& o = out;
|
||||
auto& o = outputs[0];
|
||||
|
||||
std::vector<array> copies;
|
||||
|
||||
@@ -553,4 +559,14 @@ void ScaledDotProductAttention::eval_gpu(
   d.add_temporaries(std::move(copies), s.index);
 }
 
+bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
+  return true;
+}
+
+void ScaledDotProductAttentionVJP::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  throw std::runtime_error("NYI");
+}
+
 } // namespace mlx::core::fast
@@ -30,6 +30,14 @@ bool fast::ScaledDotProductAttention::use_fallback(
     bool has_mask,
     bool has_arr_mask,
     bool do_causal,
+    bool is_training,
+    bool output_logsumexp,
     Stream s) {
   return true;
 }
 
+bool fast::ScaledDotProductAttentionVJP::use_fallback(
+    const array& q,
+    Stream s) {
+  return true;
+}
@@ -153,7 +161,8 @@ NO_GPU_MULTI(LayerNormVJP)
 NO_GPU_USE_FALLBACK(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
 NO_GPU_USE_FALLBACK(RoPE)
-NO_GPU(ScaledDotProductAttention)
+NO_GPU_MULTI(ScaledDotProductAttention)
+NO_GPU_MULTI(ScaledDotProductAttentionVJP)
 NO_GPU_MULTI(ConvertFP8)
 NO_GPU_MULTI(Quantize)
 NO_GPU_MULTI(CustomKernel)
mlx/fast.cpp (83 changed lines)
@@ -6,6 +6,7 @@
 #include "mlx/fast_primitives.h"
 #include "mlx/ops.h"
 #include "mlx/transforms.h"
+#include "mlx/transforms_impl.h"
 
 namespace mlx::core::fast {
@@ -784,22 +785,88 @@ array scaled_dot_product_attention(
     inputs.push_back(astype(*sinks, final_type, stream));
   }
 
+  bool is_training = detail::in_grad_tracing();
+  bool has_fast_vjp = !ScaledDotProductAttentionVJP::use_fallback(q, stream);
+  bool output_logsumexp = is_training && has_fast_vjp;
   if (!ScaledDotProductAttention::use_fallback(
-          q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
-    auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
-    return array(
-        std::move(out_shape),
-        final_type,
-        std::make_shared<ScaledDotProductAttention>(
-            stream, fallback, scale, do_causal, has_sinks),
-        std::move(inputs));
+          q,
+          k,
+          v,
+          has_mask,
+          has_arr_mask,
+          do_causal,
+          is_training,
+          output_logsumexp,
+          stream)) {
+    Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
+    auto primitive = std::make_shared<ScaledDotProductAttention>(
+        stream, fallback, scale, do_causal, has_sinks, output_logsumexp);
+    if (output_logsumexp) {
+      return array::make_arrays(
+          {std::move(out_shape), Shape{q.shape(0), q.shape(1), q.shape(2), 1}},
+          {final_type, float32},
+          primitive,
+          std::move(inputs))[0];
+    } else {
+      return array(
+          std::move(out_shape), final_type, primitive, std::move(inputs));
+    }
   }
   return fallback(std::move(inputs))[0];
 }
 
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
assert(primals.size() >= 3);
|
||||
assert(cotangents.size() == outputs.size());
|
||||
|
||||
auto s = stream();
|
||||
if (ScaledDotProductAttentionVJP::use_fallback(primals[0], s)) {
|
||||
assert(outputs.size() == 1);
|
||||
return Custom::vjp(primals, cotangents, argnums, outputs);
|
||||
}
|
||||
|
||||
auto fallback = [sdpa = fallback_, s](const std::vector<array>& inputs) {
|
||||
std::vector<array> primals(inputs.begin(), std::prev(inputs.end()));
|
||||
auto [_, vjps] = mlx::core::vjp(sdpa, primals, {inputs.back()});
|
||||
return vjps;
|
||||
};
|
||||
|
||||
std::vector<Shape> shapes;
|
||||
std::vector<Dtype> dtypes;
|
||||
for (int i = 0; i < primals.size(); ++i) {
|
||||
shapes.push_back(primals[i].shape());
|
||||
dtypes.push_back(primals[i].dtype());
|
||||
}
|
||||
auto primitive = std::make_shared<ScaledDotProductAttentionVJP>(
|
||||
s, fallback, scale_, do_causal_, has_sinks_);
|
||||
std::vector<array> inputs = primals;
|
||||
inputs.push_back(outputs[0]);
|
||||
inputs.push_back(outputs[1]);
|
||||
inputs.push_back(cotangents[0]);
|
||||
auto vjps = array::make_arrays(std::move(shapes), dtypes, primitive, inputs);
|
||||
|
||||
std::vector<array> returned_vjps;
|
||||
for (int arg : argnums) {
|
||||
returned_vjps.push_back(std::move(vjps[arg]));
|
||||
}
|
||||
return returned_vjps;
|
||||
}
|
||||
|
||||
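A summary of the wiring the vjp above sets up, as a reading aid (it matches the assert(inputs.size() == 6) layout in the CUDA backend's eval_gpu earlier in the diff):

// Inputs and outputs of the ScaledDotProductAttentionVJP primitive:
//
//   inputs[0..2]  q, k, v             forward primals
//   inputs[3]     outputs[0] = o      forward output
//   inputs[4]     outputs[1] = stats  logsumexp saved by the forward pass
//   inputs[5]     cotangents[0] = d_o
//   outputs       {d_q, d_k, d_v}     one cotangent per primal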
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
   const ScaledDotProductAttention& a_other =
       static_cast<const ScaledDotProductAttention&>(other);
   return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ &&
-      has_sinks_ == a_other.has_sinks_;
+      has_sinks_ == a_other.has_sinks_ &&
+      output_logsumexp_ == a_other.output_logsumexp_;
 }
 
+bool ScaledDotProductAttentionVJP::is_equivalent(const Primitive& other) const {
+  const ScaledDotProductAttentionVJP& a_other =
+      static_cast<const ScaledDotProductAttentionVJP&>(other);
+  return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ &&
+      has_sinks_ == a_other.has_sinks_;
+}
@@ -15,7 +15,7 @@ class Custom : public Primitive {
   explicit Custom(
       Stream stream,
       std::function<std::vector<array>(std::vector<array>)> fallback)
-      : Primitive(stream), fallback_(fallback) {}
+      : Primitive(stream), fallback_(std::move(fallback)) {}
 
   virtual std::pair<std::vector<array>, std::vector<int>> vmap(
       const std::vector<array>& inputs,
@@ -32,7 +32,7 @@ class Custom : public Primitive {
       const std::vector<int>& argnums,
       const std::vector<array>& outputs) override;
 
- private:
+ protected:
   std::function<std::vector<array>(std::vector<array>)> fallback_;
 };
@@ -42,7 +42,7 @@ class RMSNorm : public Custom {
       Stream stream,
       std::function<std::vector<array>(std::vector<array>)> fallback,
       float eps)
-      : Custom(stream, fallback), eps_(eps) {}
+      : Custom(stream, std::move(fallback)), eps_(eps) {}
 
   static bool use_fallback(Stream stream);
@@ -77,7 +77,7 @@ class RMSNormVJP : public Custom {
       Stream stream,
       std::function<std::vector<array>(std::vector<array>)> fallback,
       float eps)
-      : Custom(stream, fallback), eps_(eps) {}
+      : Custom(stream, std::move(fallback)), eps_(eps) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
@@ -102,7 +102,7 @@ class LayerNorm : public Custom {
       Stream stream,
       std::function<std::vector<array>(std::vector<array>)> fallback,
       float eps)
-      : Custom(stream, fallback), eps_(eps) {}
+      : Custom(stream, std::move(fallback)), eps_(eps) {}
 
   static bool use_fallback(Stream s);
@@ -136,7 +136,7 @@ class LayerNormVJP : public Custom {
       Stream stream,
       std::function<std::vector<array>(std::vector<array>)> fallback,
       float eps)
-      : Custom(stream, fallback), eps_(eps) {}
+      : Custom(stream, std::move(fallback)), eps_(eps) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
@@ -165,7 +165,7 @@ class RoPE : public Custom {
       float base,
       float scale,
       bool forward)
-      : Custom(stream, fallback),
+      : Custom(stream, std::move(fallback)),
         dims_(dims),
         traditional_(traditional),
         base_(base),
@@ -205,16 +205,18 @@ class RoPE : public Custom {
 
 class ScaledDotProductAttention : public Custom {
  public:
-  explicit ScaledDotProductAttention(
+  ScaledDotProductAttention(
       Stream stream,
       std::function<std::vector<array>(std::vector<array>)> fallback,
       float scale,
       bool do_causal,
-      bool has_sinks)
-      : Custom(stream, fallback),
+      bool has_sinks,
+      bool output_logsumexp)
+      : Custom(stream, std::move(fallback)),
         scale_(scale),
         do_causal_(do_causal),
-        has_sinks_(has_sinks) {}
+        has_sinks_(has_sinks),
+        output_logsumexp_(output_logsumexp) {}
 
   static bool use_fallback(
       const array& q,
@@ -223,6 +225,8 @@ class ScaledDotProductAttention : public Custom {
       bool has_mask,
       bool has_arr_mask,
       bool do_causal,
+      bool is_training,
+      bool output_logsumexp,
       Stream s);
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -231,15 +235,55 @@ class ScaledDotProductAttention : public Custom {
   }
 
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
-    eval_gpu(inputs, outputs[0]);
-  }
+      override;
 
+  std::vector<array> vjp(
+      const std::vector<array>& primals,
+      const std::vector<array>& cotangents,
+      const std::vector<int>& argnums,
+      const std::vector<array>& outputs) override;
 
-  void eval_gpu(const std::vector<array>& inputs, array& out);
   bool is_equivalent(const Primitive& other) const override;
 
   DEFINE_NAME(ScaledDotProductAttention);
   DEFINE_INPUT_OUTPUT_SHAPE()
   auto state() const {
-    return std::make_tuple(nullptr, scale_, do_causal_, has_sinks_);
+    return std::make_tuple(
+        nullptr, scale_, do_causal_, has_sinks_, output_logsumexp_);
   }
 
  private:
   float scale_;
   bool do_causal_;
   bool has_sinks_;
+  bool output_logsumexp_;
 };
 
+class ScaledDotProductAttentionVJP : public Custom {
+ public:
+  ScaledDotProductAttentionVJP(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      float scale,
+      bool do_causal,
+      bool has_sinks)
+      : Custom(stream, std::move(fallback)),
+        scale_(scale),
+        do_causal_(do_causal),
+        has_sinks_(has_sinks) {}
+
+  static bool use_fallback(const array& q, Stream s);
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override {
+    throw std::runtime_error("NYI");
+  }
+
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_NAME(ScaledDotProductAttentionVJP);
+  bool is_equivalent(const Primitive& other) const override;
+  auto state() const {
+    return std::make_tuple(nullptr, scale_, do_causal_, has_sinks_);
+  }
@@ -288,7 +332,7 @@ class Quantize : public Custom {
       int bits,
       QuantizationMode mode,
       bool dequantize)
-      : Custom(stream, fallback),
+      : Custom(stream, std::move(fallback)),
         group_size_(group_size),
         bits_(bits),
         mode_(mode),
@@ -738,6 +738,39 @@ class TestSDPA(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(out, expected, atol=1e-5))
 
+    def test_sdpa_grad(self):
+        B, N_kv, T, D = (2, 8, 128, 64)
+        scale = D**-0.5
+
+        f1 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
+        f2 = lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
+
+        f3 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale).sum()
+        f4 = lambda q, k, v: (
+            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
+        ).sum()
+
+        # High tolerance due to cuDNN SDPA kernel requiring tf32.
+        tolerance = {"rtol": 1e-2, "atol": 1e-2}
+
+        for N_q in (8, 32):
+            q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)
+            k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
+            v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
+
+            cotan = mx.ones_like(q)
+            o1, vjp1 = mx.vjp(f1, [q, k, v], [cotan])
+            o2, vjp2 = mx.vjp(f2, [q, k, v], [cotan])
+
+            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
+            for i in range(3):
+                self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
+
+            g1 = mx.grad(f3)(q, k, v)
+            g2 = mx.grad(f4)(q, k, v)
+
+            self.assertTrue(mx.allclose(g1, g2, **tolerance))
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner(failfast=True)