mirror of https://github.com/ml-explore/mlx.git
[CUDA] Support array mask in SDPA (#2822)
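This change routes array masks (additive float masks as well as boolean masks) through the cuDNN SDPA path on CUDA instead of falling back. A minimal usage sketch, with shapes picked purely for illustration (any mask that broadcasts against the attention scores should behave the same way):

    import mlx.core as mx

    B, H, L, D = 2, 8, 128, 64
    q = mx.random.normal((B, H, L, D), dtype=mx.float16)
    k = mx.random.normal((B, H, L, D), dtype=mx.float16)
    v = mx.random.normal((B, H, L, D), dtype=mx.float16)

    # Additive float mask: added to the attention scores before the softmax.
    mask = mx.random.normal((B, H, L, L), dtype=mx.float16)
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5, mask=mask)

    # Boolean mask: True keeps a position, False masks it out.
    bool_mask = mx.random.uniform(0, 1, (B, H, L, L)) < 0.5
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5, mask=bool_mask)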
@@ -107,6 +107,8 @@ struct SDPACacheKey {
   std::array<int64_t, QKV_NDIM> k_strides;
   std::array<int64_t, QKV_NDIM> v_strides;
   bool do_causal;
+  std::array<int, QKV_NDIM> mask_shape;
+  std::array<int64_t, QKV_NDIM> mask_strides;
   bool output_logsumexp;
 };

@@ -116,6 +118,7 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
     const array& k,
     const array& v,
     bool do_causal,
+    const std::optional<array>& mask_arr,
     bool output_logsumexp = true) {
   BytesKey<SDPACacheKey> cache_key;
   cache_key.pod = {
@@ -128,8 +131,14 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
       vector_key<QKV_NDIM>(k.strides()),
       vector_key<QKV_NDIM>(v.strides()),
       do_causal,
+      {},
+      {},
       output_logsumexp,
   };
+  if (mask_arr) {
+    cache_key.pod.mask_shape = vector_key<QKV_NDIM>(mask_arr->shape());
+    cache_key.pod.mask_strides = vector_key<QKV_NDIM>(mask_arr->strides());
+  }
   return cache_key;
 }

@@ -150,6 +159,7 @@ enum UIDS {
   K,
   V,
   SCALE,
+  BIAS,
   O,
   STATS,
   // Backward graph:
@@ -165,6 +175,7 @@ fe::graph::Graph build_sdpa_graph(
     const array& k,
     const array& v,
     bool do_causal,
+    const std::optional<array>& mask_arr,
     bool output_logsumexp,
     const array& o,
     const array& stats) {
@@ -198,6 +209,11 @@ fe::graph::Graph build_sdpa_graph(
           .set_attn_scale(scale)
           .set_causal_mask(do_causal)
           .set_generate_stats(output_logsumexp);
+  if (mask_arr) {
+    auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
+    set_tensor_attrs(bias_, BIAS, *mask_arr);
+    options.set_bias(bias_);
+  }

   auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
   o_->set_output(true);
@@ -224,6 +240,7 @@ fe::graph::Graph build_sdpa_backward_graph(
     const array& k,
     const array& v,
     bool do_causal,
+    const std::optional<array>& mask_arr,
     const array& o,
     const array& d_o,
     const array& stats,
@@ -266,6 +283,11 @@ fe::graph::Graph build_sdpa_backward_graph(
           .set_name("sdpa_backward_cudnn")
           .set_attn_scale(scale)
           .set_causal_mask(do_causal);
+  if (mask_arr) {
+    auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
+    set_tensor_attrs(bias_, BIAS, *mask_arr);
+    options.set_bias(bias_);
+  }

   auto [d_q_, d_k_, d_v_] =
       graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
@@ -318,8 +340,6 @@ bool supports_sdpa_cudnn(
     const array& q,
     const array& k,
     const array& v,
-    bool has_mask,
-    bool do_causal,
     Stream s) {
   static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
   if (!enabled) {
@@ -331,17 +351,6 @@ bool supports_sdpa_cudnn(
     return false;
   }

-  if (has_mask) {
-    // TODO: Support array masks.
-    if (!do_causal) {
-      return false;
-    }
-    // FIXME: Causal mask generates wrong results when L_Q != L_K.
-    if (q.shape(2) != k.shape(2)) {
-      return false;
-    }
-  }
-
   // Only use cuDNN for prefilling and training.
   if (q.shape(2) != k.shape(2)) {
     return false;
@@ -365,6 +374,7 @@ void sdpa_cudnn(
     array& o,
     array& stats,
     bool do_causal,
+    const std::optional<array>& mask_arr,
     bool output_logsumexp,
     Stream s) {
   auto& encoder = cu::get_command_encoder(s);
@@ -376,19 +386,21 @@ void sdpa_cudnn(
   encoder.set_input_array(k);
   encoder.set_input_array(v);
   encoder.set_output_array(o);
+  if (mask_arr) {
+    encoder.set_input_array(*mask_arr);
+  }
   if (output_logsumexp) {
     stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
     encoder.set_output_array(stats);
   }

   // Search cache.
-  auto cache_key =
-      build_sdpa_cache_key(encoder, q, k, v, do_causal, output_logsumexp);
+  auto cache_key = build_sdpa_cache_key(
+      encoder, q, k, v, do_causal, mask_arr, output_logsumexp);
   auto it = sdpa_cache().find(cache_key);
   if (it == sdpa_cache().end()) {
     auto graph = build_sdpa_graph(
-        handle, q, k, v, do_causal, output_logsumexp, o, stats);
+        handle, q, k, v, do_causal, mask_arr, output_logsumexp, o, stats);
     it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
   }
   auto& graph = it->second;
@@ -399,6 +411,9 @@ void sdpa_cudnn(
       {V, const_cast<void*>(gpu_ptr<void>(v))},
       {SCALE, &scale},
       {O, gpu_ptr<void>(o)}};
+  if (mask_arr) {
+    variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
+  }
   if (output_logsumexp) {
     variant_pack[STATS] = gpu_ptr<void>(stats);
   }
@@ -414,6 +429,7 @@ void sdpa_backward_cudnn(
     const array& o,
     const array& stats,
     bool do_causal,
+    const std::optional<array>& mask_arr,
     const array& d_o,
     array& d_q,
     array& d_k,
@@ -435,13 +451,16 @@ void sdpa_backward_cudnn(
   encoder.set_output_array(d_q);
   encoder.set_output_array(d_k);
   encoder.set_output_array(d_v);
+  if (mask_arr) {
+    encoder.set_input_array(*mask_arr);
+  }

   // Search cache.
-  auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal);
+  auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal, mask_arr);
   auto it = sdpa_backward_cache().find(cache_key);
   if (it == sdpa_backward_cache().end()) {
     auto graph = build_sdpa_backward_graph(
-        handle, q, k, v, do_causal, o, d_o, stats, d_q, d_k, d_v);
+        handle, q, k, v, do_causal, mask_arr, o, d_o, stats, d_q, d_k, d_v);
     it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
   }
   auto& graph = it->second;
@@ -457,6 +476,9 @@ void sdpa_backward_cudnn(
       {D_Q, gpu_ptr<void>(d_q)},
       {D_K, gpu_ptr<void>(d_k)},
       {D_V, gpu_ptr<void>(d_v)}};
+  if (mask_arr) {
+    variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
+  }

   execute_graph(encoder, handle, graph, variant_pack);
 }
@@ -498,7 +520,11 @@ bool ScaledDotProductAttention::use_fallback(

   return !supports_sdpa_vector(
              q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
-      !supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s);
+      !supports_sdpa_cudnn(q, k, v, s);
+}
+
+bool ScaledDotProductAttention::supports_bool_mask() {
+  return false;
 }

 void ScaledDotProductAttention::eval_gpu(
@@ -516,6 +542,11 @@ void ScaledDotProductAttention::eval_gpu(
   bool has_mask = inputs.size() - has_sinks_ > 3;
   bool has_arr_mask = has_mask && !do_causal_;

+  std::optional<array> mask_arr;
+  if (has_arr_mask) {
+    mask_arr = prepare_sdpa_input(inputs[3], s);
+  }
+
   if (supports_sdpa_vector(
           q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
     if (has_sinks_) {
@@ -524,7 +555,17 @@ void ScaledDotProductAttention::eval_gpu(
       sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
     }
   } else {
-    sdpa_cudnn(q, k, v, scale_, out, stats, do_causal_, output_logsumexp_, s);
+    sdpa_cudnn(
+        q,
+        k,
+        v,
+        scale_,
+        out,
+        stats,
+        do_causal_,
+        mask_arr,
+        output_logsumexp_,
+        s);
   }
 }

@@ -544,13 +585,21 @@ void ScaledDotProductAttentionVJP::eval_gpu(

   auto& s = stream();

-  assert(inputs.size() == 6);
+  assert(inputs.size() >= 6);
+  int primals_size = inputs.size() - 3;
+  bool has_arr_mask = primals_size > 3 + has_sinks_;
+
   array q = prepare_sdpa_input(inputs[0], s);
   array k = prepare_sdpa_input(inputs[1], s);
   array v = prepare_sdpa_input(inputs[2], s);
-  array o = prepare_sdpa_input(inputs[3], s);
-  array stats = prepare_sdpa_input(inputs[4], s);
-  array d_o = prepare_sdpa_input(inputs[5], s);
+  array o = prepare_sdpa_input(inputs[primals_size], s);
+  array stats = prepare_sdpa_input(inputs[primals_size + 1], s);
+  array d_o = prepare_sdpa_input(inputs[primals_size + 2], s);

+  std::optional<array> mask_arr;
+  if (has_arr_mask) {
+    mask_arr = prepare_sdpa_input(inputs[3], s);
+  }
+
   assert(outputs.size() == 3);
   auto& d_q = outputs[0];
@@ -558,7 +607,7 @@ void ScaledDotProductAttentionVJP::eval_gpu(
   auto& d_v = outputs[2];

   sdpa_backward_cudnn(
-      q, k, v, scale_, o, stats, do_causal_, d_o, d_q, d_k, d_v, s);
+      q, k, v, scale_, o, stats, do_causal_, mask_arr, d_o, d_q, d_k, d_v, s);
 }

 } // namespace fast
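In the hunks above, the array mask is wired into the cuDNN graph as an extra bias tensor (the BIAS uid set via options.set_bias), so it is effectively added to the pre-softmax attention scores, and the same bias is fed to the backward graph. As a point of reference, a minimal sketch of additive-mask attention, assuming q, k, v all have shape (B, H, L, D) and mask broadcasts to (B, H, L, L):

    import mlx.core as mx

    def ref_sdpa_additive_mask(q, k, v, scale, mask):
        # Scores have shape (B, H, L, L); the mask is simply added before the softmax.
        scores = (q * scale) @ k.transpose(0, 1, 3, 2) + mask
        return mx.softmax(scores, axis=-1) @ v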
@@ -569,6 +569,10 @@ bool ScaledDotProductAttention::use_fallback(
   return !(supports_sdpa_full || supports_sdpa_vector);
 }

+bool ScaledDotProductAttention::supports_bool_mask() {
+  return true;
+}
+
 void ScaledDotProductAttention::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
@@ -36,6 +36,10 @@ bool fast::ScaledDotProductAttention::use_fallback(
   return true;
 }

+bool fast::ScaledDotProductAttention::supports_bool_mask() {
+  return false;
+}
+
 bool fast::ScaledDotProductAttentionVJP::use_fallback(
     const array& q,
     Stream s) {
mlx/fast.cpp
@@ -800,6 +800,15 @@ array scaled_dot_product_attention(
           is_training,
           output_logsumexp,
           stream)) {
+    if (has_bool_mask && !ScaledDotProductAttention::supports_bool_mask()) {
+      // Convert bool mask to additive mask.
+      float inf = std::numeric_limits<float>::infinity();
+      array& mask = inputs[3];
+      mask = where(
+          mask,
+          full_like(mask, 0, final_type, s),
+          full_like(mask, -inf, final_type, s));
+    }
     Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     auto primitive = std::make_shared<ScaledDotProductAttention>(
         stream, fallback, scale, do_causal, has_sinks, output_logsumexp);
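Since the CUDA backend reports supports_bool_mask() as false, boolean masks are converted to additive masks in the frontend, as in the hunk above. A rough Python equivalent of that conversion, assuming mask is a boolean array and dtype is the attention compute dtype:

    import mlx.core as mx

    def bool_to_additive(mask, dtype=mx.float16):
        # True -> 0 (keep), False -> -inf (drop), mirroring the where()/full_like() call above.
        zeros = mx.zeros(mask.shape, dtype=dtype)
        neg_inf = mx.full(mask.shape, -float("inf"), dtype=dtype)
        return mx.where(mask, zeros, neg_inf)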
@@ -839,7 +848,7 @@ std::vector<array> ScaledDotProductAttention::vjp(

   std::vector<Shape> shapes;
   std::vector<Dtype> dtypes;
-  for (int i = 0; i < primals.size(); ++i) {
+  for (int i = 0; i < /* outputs size */ 3; ++i) {
     shapes.push_back(primals[i].shape());
     dtypes.push_back(primals[i].dtype());
   }
@@ -228,6 +228,7 @@ class ScaledDotProductAttention : public Custom {
       bool is_training,
       bool output_logsumexp,
       Stream s);
+  static bool supports_bool_mask();

   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
@@ -739,37 +739,69 @@ class TestSDPA(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(out, expected, atol=1e-5))

     def test_sdpa_grad(self):
-        B, N_kv, T, D = (2, 8, 128, 64)
-        scale = D**-0.5
-
-        f1 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
-        f2 = lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
-
-        f3 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale).sum()
-        f4 = lambda q, k, v: (
-            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
-        ).sum()
-
         # High tolerance due to cuDNN SDPA kernel requiring tf32.
         tolerance = {"rtol": 1e-2, "atol": 1e-2}

+        def test_vjp(slow, fast, primals):
+            cotan = mx.ones_like(primals[0])
+            o1, vjp1 = mx.vjp(slow, primals, [cotan])
+            o2, vjp2 = mx.vjp(fast, primals, [cotan])
+
+            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
+            for i in range(3):
+                self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
+
+        def test_grad(slow, fast, args):
+            g1 = mx.grad(slow)(*args)
+            g2 = mx.grad(fast)(*args)
+
+            self.assertTrue(mx.allclose(g1, g2, **tolerance))
+
+        sdpa_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
+            q, k, v, scale=scale, mask=mask
+        )
+        sdpa_mask_fast = lambda q, k, v, mask: mx.fast.scaled_dot_product_attention(
+            q, k, v, scale=scale, mask=mask
+        )
+
+        loss_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
+            q, k, v, scale=scale, mask=mask
+        ).sum()
+        loss_mask_fast = lambda q, k, v, mask: (
+            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
+        ).sum()
+
+        B, N_kv, T, D = (2, 8, 128, 64)
+        scale = D**-0.5
+
         for N_q in (8, 32):
             q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)
             k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
             v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)

-            cotan = mx.ones_like(q)
-            o1, vjp1 = mx.vjp(f1, [q, k, v], [cotan])
-            o2, vjp2 = mx.vjp(f2, [q, k, v], [cotan])
-
-            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
-            for i in range(3):
-                self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
-
-            g1 = mx.grad(f3)(q, k, v)
-            g2 = mx.grad(f4)(q, k, v)
-
-            self.assertTrue(mx.allclose(g1, g2, **tolerance))
+            mask_additive = mx.random.normal((B, N_q, T, T), dtype=mx.float16)
+            mask_bool = mx.random.uniform(0, 1, (B, N_q, T, T), dtype=mx.float16) < 0.5
+
+            for mask in (mask_additive, mask_bool):
+                test_vjp(sdpa_mask_slow, sdpa_mask_fast, [q, k, v, mask])
+                test_grad(loss_mask_slow, loss_mask_fast, [q, k, v, mask])
+
+            for mask in (None, "causal"):
+                sdpa_slow = lambda q, k, v: mlx_ref_attn(
+                    q, k, v, scale=scale, mask=mask
+                )
+                sdpa_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                    q, k, v, scale=scale, mask=mask
+                )
+                test_vjp(sdpa_slow, sdpa_fast, [q, k, v])
+
+                loss_slow = lambda q, k, v: mlx_ref_attn(
+                    q, k, v, scale=scale, mask=mask
+                ).sum()
+                loss_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                    q, k, v, scale=scale, mask=mask
+                ).sum()
+                test_grad(loss_slow, loss_fast, [q, k, v])


 if __name__ == "__main__":