Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
[CUDA] Support array mask in SDPA (#2822)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
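Before the diff, a minimal usage sketch of what this change enables on the CUDA backend: passing an array mask to the fused SDPA. It mirrors the Python tests added below; shapes, dtypes, and variable names are illustrative only.

import mlx.core as mx

B, N_q, N_kv, T, D = 2, 32, 8, 128, 64
q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)
k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)

# Additive array mask broadcast onto the attention scores. Boolean masks are
# converted to this additive form for backends whose supports_bool_mask()
# returns false (the CUDA backend in this diff).
mask = mx.random.normal(shape=(B, N_q, T, T), dtype=mx.float16)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5, mask=mask)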
@@ -107,6 +107,8 @@ struct SDPACacheKey {
   std::array<int64_t, QKV_NDIM> k_strides;
   std::array<int64_t, QKV_NDIM> v_strides;
   bool do_causal;
+  std::array<int, QKV_NDIM> mask_shape;
+  std::array<int64_t, QKV_NDIM> mask_strides;
   bool output_logsumexp;
 };

@@ -116,6 +118,7 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
     const array& k,
     const array& v,
     bool do_causal,
+    const std::optional<array>& mask_arr,
     bool output_logsumexp = true) {
   BytesKey<SDPACacheKey> cache_key;
   cache_key.pod = {
@@ -128,8 +131,14 @@ inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
       vector_key<QKV_NDIM>(k.strides()),
       vector_key<QKV_NDIM>(v.strides()),
       do_causal,
+      {},
+      {},
       output_logsumexp,
   };
+  if (mask_arr) {
+    cache_key.pod.mask_shape = vector_key<QKV_NDIM>(mask_arr->shape());
+    cache_key.pod.mask_strides = vector_key<QKV_NDIM>(mask_arr->strides());
+  }
   return cache_key;
 }

@@ -150,6 +159,7 @@ enum UIDS {
   K,
   V,
   SCALE,
+  BIAS,
   O,
   STATS,
   // Backward graph:

@@ -165,6 +175,7 @@ fe::graph::Graph build_sdpa_graph(
     const array& k,
     const array& v,
     bool do_causal,
+    const std::optional<array>& mask_arr,
     bool output_logsumexp,
     const array& o,
     const array& stats) {

@@ -198,6 +209,11 @@ fe::graph::Graph build_sdpa_graph(
           .set_attn_scale(scale)
           .set_causal_mask(do_causal)
           .set_generate_stats(output_logsumexp);
+  if (mask_arr) {
+    auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
+    set_tensor_attrs(bias_, BIAS, *mask_arr);
+    options.set_bias(bias_);
+  }

   auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
   o_->set_output(true);
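For orientation between the forward and backward graph hunks: the BIAS tensor registered above carries the array mask as an additive term on the pre-softmax scores, which is also what the Python reference (mlx_ref_attn) in the tests computes. A minimal sketch of that semantics, assuming query and key/value head counts already match; the helper name is illustrative, not part of this change.

import mlx.core as mx

def sdpa_additive_mask_ref(q, k, v, scale, mask):
    # (B, H, L_q, D) @ (B, H, D, L_k) -> (B, H, L_q, L_k) scores, plus the mask
    scores = (q * scale) @ k.swapaxes(-1, -2) + mask
    return mx.softmax(scores, axis=-1) @ v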
@@ -224,6 +240,7 @@ fe::graph::Graph build_sdpa_backward_graph(
     const array& k,
     const array& v,
     bool do_causal,
+    const std::optional<array>& mask_arr,
     const array& o,
     const array& d_o,
     const array& stats,

@@ -266,6 +283,11 @@ fe::graph::Graph build_sdpa_backward_graph(
           .set_name("sdpa_backward_cudnn")
           .set_attn_scale(scale)
           .set_causal_mask(do_causal);
+  if (mask_arr) {
+    auto bias_ = graph.tensor(fe::graph::Tensor_attributes().set_name("BIAS"));
+    set_tensor_attrs(bias_, BIAS, *mask_arr);
+    options.set_bias(bias_);
+  }

   auto [d_q_, d_k_, d_v_] =
       graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);

@@ -318,8 +340,6 @@ bool supports_sdpa_cudnn(
     const array& q,
     const array& k,
     const array& v,
-    bool has_mask,
-    bool do_causal,
     Stream s) {
   static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
   if (!enabled) {
@@ -331,17 +351,6 @@ bool supports_sdpa_cudnn(
     return false;
   }

-  if (has_mask) {
-    // TODO: Support array masks.
-    if (!do_causal) {
-      return false;
-    }
-    // FIXME: Causal mask generates wrong results when L_Q != L_K.
-    if (q.shape(2) != k.shape(2)) {
-      return false;
-    }
-  }
-
   // Only use cuDNN for prefilling and training.
   if (q.shape(2) != k.shape(2)) {
     return false;
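The hunk above drops the has_mask/do_causal restrictions, so the cuDNN path is now gated only by the MLX_CUDA_USE_CUDNN_SPDA environment variable and the prefilling check that follows. A sketch of opting out from Python, assuming the variable is set before the first SDPA dispatch (it is read once via a static initializer):

import os

# Disable the cuDNN SDPA path; MLX then uses its other CUDA attention kernels
# or the generic fallback.
os.environ["MLX_CUDA_USE_CUDNN_SPDA"] = "0"

import mlx.core as mx  # import (and call SDPA) only after setting the variable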
@@ -365,6 +374,7 @@ void sdpa_cudnn(
     array& o,
     array& stats,
     bool do_causal,
+    const std::optional<array>& mask_arr,
     bool output_logsumexp,
     Stream s) {
   auto& encoder = cu::get_command_encoder(s);
@@ -376,19 +386,21 @@ void sdpa_cudnn(
   encoder.set_input_array(k);
   encoder.set_input_array(v);
   encoder.set_output_array(o);

+  if (mask_arr) {
+    encoder.set_input_array(*mask_arr);
+  }
   if (output_logsumexp) {
     stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
     encoder.set_output_array(stats);
   }

   // Search cache.
-  auto cache_key =
-      build_sdpa_cache_key(encoder, q, k, v, do_causal, output_logsumexp);
+  auto cache_key = build_sdpa_cache_key(
+      encoder, q, k, v, do_causal, mask_arr, output_logsumexp);
   auto it = sdpa_cache().find(cache_key);
   if (it == sdpa_cache().end()) {
     auto graph = build_sdpa_graph(
-        handle, q, k, v, do_causal, output_logsumexp, o, stats);
+        handle, q, k, v, do_causal, mask_arr, output_logsumexp, o, stats);
     it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
   }
   auto& graph = it->second;
@@ -399,6 +411,9 @@ void sdpa_cudnn(
       {V, const_cast<void*>(gpu_ptr<void>(v))},
       {SCALE, &scale},
       {O, gpu_ptr<void>(o)}};
+  if (mask_arr) {
+    variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
+  }
   if (output_logsumexp) {
     variant_pack[STATS] = gpu_ptr<void>(stats);
   }
@@ -414,6 +429,7 @@ void sdpa_backward_cudnn(
     const array& o,
     const array& stats,
     bool do_causal,
+    const std::optional<array>& mask_arr,
     const array& d_o,
     array& d_q,
     array& d_k,
@@ -435,13 +451,16 @@ void sdpa_backward_cudnn(
   encoder.set_output_array(d_q);
   encoder.set_output_array(d_k);
   encoder.set_output_array(d_v);
+  if (mask_arr) {
+    encoder.set_input_array(*mask_arr);
+  }

   // Search cache.
-  auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal);
+  auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal, mask_arr);
   auto it = sdpa_backward_cache().find(cache_key);
   if (it == sdpa_backward_cache().end()) {
     auto graph = build_sdpa_backward_graph(
-        handle, q, k, v, do_causal, o, d_o, stats, d_q, d_k, d_v);
+        handle, q, k, v, do_causal, mask_arr, o, d_o, stats, d_q, d_k, d_v);
     it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
   }
   auto& graph = it->second;
@@ -457,6 +476,9 @@ void sdpa_backward_cudnn(
       {D_Q, gpu_ptr<void>(d_q)},
       {D_K, gpu_ptr<void>(d_k)},
      {D_V, gpu_ptr<void>(d_v)}};
+  if (mask_arr) {
+    variant_pack[BIAS] = const_cast<void*>(gpu_ptr<void>(*mask_arr));
+  }

   execute_graph(encoder, handle, graph, variant_pack);
 }
@@ -498,7 +520,11 @@ bool ScaledDotProductAttention::use_fallback(

   return !supports_sdpa_vector(
              q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
-      !supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s);
+      !supports_sdpa_cudnn(q, k, v, s);
 }

+bool ScaledDotProductAttention::supports_bool_mask() {
+  return false;
+}
+
 void ScaledDotProductAttention::eval_gpu(
@@ -516,6 +542,11 @@ void ScaledDotProductAttention::eval_gpu(
   bool has_mask = inputs.size() - has_sinks_ > 3;
   bool has_arr_mask = has_mask && !do_causal_;

+  std::optional<array> mask_arr;
+  if (has_arr_mask) {
+    mask_arr = prepare_sdpa_input(inputs[3], s);
+  }
+
   if (supports_sdpa_vector(
           q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
     if (has_sinks_) {
@@ -524,7 +555,17 @@ void ScaledDotProductAttention::eval_gpu(
       sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
     }
   } else {
-    sdpa_cudnn(q, k, v, scale_, out, stats, do_causal_, output_logsumexp_, s);
+    sdpa_cudnn(
+        q,
+        k,
+        v,
+        scale_,
+        out,
+        stats,
+        do_causal_,
+        mask_arr,
+        output_logsumexp_,
+        s);
   }
 }
@@ -544,13 +585,21 @@ void ScaledDotProductAttentionVJP::eval_gpu(

   auto& s = stream();

-  assert(inputs.size() == 6);
+  assert(inputs.size() >= 6);
+  int primals_size = inputs.size() - 3;
+  bool has_arr_mask = primals_size > 3 + has_sinks_;

   array q = prepare_sdpa_input(inputs[0], s);
   array k = prepare_sdpa_input(inputs[1], s);
   array v = prepare_sdpa_input(inputs[2], s);
-  array o = prepare_sdpa_input(inputs[3], s);
-  array stats = prepare_sdpa_input(inputs[4], s);
-  array d_o = prepare_sdpa_input(inputs[5], s);
+  array o = prepare_sdpa_input(inputs[primals_size], s);
+  array stats = prepare_sdpa_input(inputs[primals_size + 1], s);
+  array d_o = prepare_sdpa_input(inputs[primals_size + 2], s);
+
+  std::optional<array> mask_arr;
+  if (has_arr_mask) {
+    mask_arr = prepare_sdpa_input(inputs[3], s);
+  }

   assert(outputs.size() == 3);
   auto& d_q = outputs[0];
@@ -558,7 +607,7 @@ void ScaledDotProductAttentionVJP::eval_gpu(
   auto& d_v = outputs[2];

   sdpa_backward_cudnn(
-      q, k, v, scale_, o, stats, do_causal_, d_o, d_q, d_k, d_v, s);
+      q, k, v, scale_, o, stats, do_causal_, mask_arr, d_o, d_q, d_k, d_v, s);
 }

 } // namespace fast
@@ -569,6 +569,10 @@ bool ScaledDotProductAttention::use_fallback(
   return !(supports_sdpa_full || supports_sdpa_vector);
 }

+bool ScaledDotProductAttention::supports_bool_mask() {
+  return true;
+}
+
 void ScaledDotProductAttention::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {

@@ -36,6 +36,10 @@ bool fast::ScaledDotProductAttention::use_fallback(
   return true;
 }

+bool fast::ScaledDotProductAttention::supports_bool_mask() {
+  return false;
+}
+
 bool fast::ScaledDotProductAttentionVJP::use_fallback(
     const array& q,
     Stream s) {
mlx/fast.cpp
@@ -800,6 +800,15 @@ array scaled_dot_product_attention(
           is_training,
           output_logsumexp,
           stream)) {
+    if (has_bool_mask && !ScaledDotProductAttention::supports_bool_mask()) {
+      // Convert bool mask to additive mask.
+      float inf = std::numeric_limits<float>::infinity();
+      array& mask = inputs[3];
+      mask = where(
+          mask,
+          full_like(mask, 0, final_type, s),
+          full_like(mask, -inf, final_type, s));
+    }
     Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     auto primitive = std::make_shared<ScaledDotProductAttention>(
         stream, fallback, scale, do_causal, has_sinks, output_logsumexp);
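The conversion above maps a boolean mask to an additive one: True keeps a position at 0, False pushes it to -inf so the softmax suppresses it. A Python-level equivalent, written here only as an illustration of the same trick:

import mlx.core as mx

def bool_to_additive_mask(mask_bool, dtype=mx.float16):
    # True -> 0.0 (keep), False -> -inf (dropped by the softmax)
    return mx.where(
        mask_bool,
        mx.zeros(mask_bool.shape, dtype=dtype),
        mx.full(mask_bool.shape, -float("inf"), dtype=dtype),
    )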
@@ -839,7 +848,7 @@ std::vector<array> ScaledDotProductAttention::vjp(

   std::vector<Shape> shapes;
   std::vector<Dtype> dtypes;
-  for (int i = 0; i < primals.size(); ++i) {
+  for (int i = 0; i < /* outputs size */ 3; ++i) {
     shapes.push_back(primals[i].shape());
     dtypes.push_back(primals[i].dtype());
   }

@@ -228,6 +228,7 @@ class ScaledDotProductAttention : public Custom {
       bool is_training,
       bool output_logsumexp,
       Stream s);
+  static bool supports_bool_mask();

   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
@@ -739,37 +739,69 @@ class TestSDPA(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(out, expected, atol=1e-5))

     def test_sdpa_grad(self):
-        B, N_kv, T, D = (2, 8, 128, 64)
-        scale = D**-0.5
-
-        f1 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
-        f2 = lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
-
-        f3 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale).sum()
-        f4 = lambda q, k, v: (
-            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
-        ).sum()
-
         # High tolerance due to cuDNN SDPA kernel requiring tf32.
         tolerance = {"rtol": 1e-2, "atol": 1e-2}

+        def test_vjp(slow, fast, primals):
+            cotan = mx.ones_like(primals[0])
+            o1, vjp1 = mx.vjp(slow, primals, [cotan])
+            o2, vjp2 = mx.vjp(fast, primals, [cotan])
+
+            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
+            for i in range(3):
+                self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
+
+        def test_grad(slow, fast, args):
+            g1 = mx.grad(slow)(*args)
+            g2 = mx.grad(fast)(*args)
+
+            self.assertTrue(mx.allclose(g1, g2, **tolerance))
+
+        sdpa_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
+            q, k, v, scale=scale, mask=mask
+        )
+        sdpa_mask_fast = lambda q, k, v, mask: mx.fast.scaled_dot_product_attention(
+            q, k, v, scale=scale, mask=mask
+        )
+
+        loss_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
+            q, k, v, scale=scale, mask=mask
+        ).sum()
+        loss_mask_fast = lambda q, k, v, mask: (
+            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
+        ).sum()
+
+        B, N_kv, T, D = (2, 8, 128, 64)
+        scale = D**-0.5
+
         for N_q in (8, 32):
             q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)
             k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
             v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)

-            cotan = mx.ones_like(q)
-            o1, vjp1 = mx.vjp(f1, [q, k, v], [cotan])
-            o2, vjp2 = mx.vjp(f2, [q, k, v], [cotan])
-
-            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
-            for i in range(3):
-                self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
-
-            g1 = mx.grad(f3)(q, k, v)
-            g2 = mx.grad(f4)(q, k, v)
-
-            self.assertTrue(mx.allclose(g1, g2, **tolerance))
+            mask_additive = mx.random.normal((B, N_q, T, T), dtype=mx.float16)
+            mask_bool = mx.random.uniform(0, 1, (B, N_q, T, T), dtype=mx.float16) < 0.5
+
+            for mask in (mask_additive, mask_bool):
+                test_vjp(sdpa_mask_slow, sdpa_mask_fast, [q, k, v, mask])
+                test_grad(loss_mask_slow, loss_mask_fast, [q, k, v, mask])
+
+            for mask in (None, "causal"):
+                sdpa_slow = lambda q, k, v: mlx_ref_attn(
+                    q, k, v, scale=scale, mask=mask
+                )
+                sdpa_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                    q, k, v, scale=scale, mask=mask
+                )
+                test_vjp(sdpa_slow, sdpa_fast, [q, k, v])
+
+                loss_slow = lambda q, k, v: mlx_ref_attn(
+                    q, k, v, scale=scale, mask=mask
+                ).sum()
+                loss_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                    q, k, v, scale=scale, mask=mask
+                ).sum()
+                test_grad(loss_slow, loss_fast, [q, k, v])


 if __name__ == "__main__":