[CUDA] Support array mask in SDPA (#2822)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled

This commit is contained in:
Cheng
2025-11-26 11:08:58 +09:00
committed by GitHub
parent c9f4dc851f
commit 704fd1ae28
6 changed files with 146 additions and 47 deletions

View File

@@ -739,37 +739,69 @@ class TestSDPA(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out, expected, atol=1e-5))
def test_sdpa_grad(self):
B, N_kv, T, D = (2, 8, 128, 64)
scale = D**-0.5
f1 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
f2 = lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
f3 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale).sum()
f4 = lambda q, k, v: (
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
).sum()
# High tolerance due to cuDNN SDPA kernel requiring tf32.
tolerance = {"rtol": 1e-2, "atol": 1e-2}
def test_vjp(slow, fast, primals):
cotan = mx.ones_like(primals[0])
o1, vjp1 = mx.vjp(slow, primals, [cotan])
o2, vjp2 = mx.vjp(fast, primals, [cotan])
self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
for i in range(3):
self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
def test_grad(slow, fast, args):
g1 = mx.grad(slow)(*args)
g2 = mx.grad(fast)(*args)
self.assertTrue(mx.allclose(g1, g2, **tolerance))
sdpa_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
)
sdpa_mask_fast = lambda q, k, v, mask: mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale, mask=mask
)
loss_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
).sum()
loss_mask_fast = lambda q, k, v, mask: (
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
).sum()
B, N_kv, T, D = (2, 8, 128, 64)
scale = D**-0.5
for N_q in (8, 32):
q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)
k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
cotan = mx.ones_like(q)
o1, vjp1 = mx.vjp(f1, [q, k, v], [cotan])
o2, vjp2 = mx.vjp(f2, [q, k, v], [cotan])
mask_additive = mx.random.normal((B, N_q, T, T), dtype=mx.float16)
mask_bool = mx.random.uniform(0, 1, (B, N_q, T, T), dtype=mx.float16) < 0.5
self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
for i in range(3):
self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
for mask in (mask_additive, mask_bool):
test_vjp(sdpa_mask_slow, sdpa_mask_fast, [q, k, v, mask])
test_grad(loss_mask_slow, loss_mask_fast, [q, k, v, mask])
g1 = mx.grad(f3)(q, k, v)
g2 = mx.grad(f4)(q, k, v)
for mask in (None, "causal"):
sdpa_slow = lambda q, k, v: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
)
sdpa_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale, mask=mask
)
test_vjp(sdpa_slow, sdpa_fast, [q, k, v])
self.assertTrue(mx.allclose(g1, g2, **tolerance))
loss_slow = lambda q, k, v: mlx_ref_attn(
q, k, v, scale=scale, mask=mask
).sum()
loss_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale, mask=mask
).sum()
test_grad(loss_slow, loss_fast, [q, k, v])
if __name__ == "__main__":