Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
[CUDA] Support array mask in SDPA (#2822)
Some checks failed: every Build and Test and Nightly Build workflow for this push (lint; Linux cpu and CUDA 12.6/12.9 on x86_64 and aarch64; macOS 14.0 and 15.0; documentation; Fedora x86_64 and aarch64; and the Python 3.10-3.14 release and test matrix) has been cancelled.
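For context, what #2822 enables on the CUDA backend is passing an array-valued mask, either an additive float mask or a boolean mask, to the fused SDPA kernel. A minimal usage sketch (shapes and calls mirror the updated test below; assumes an MLX build with CUDA support):

import mlx.core as mx

B, N_q, N_kv, T, D = 2, 32, 8, 128, 64
scale = D**-0.5

q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)
k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)

# Additive float mask: broadcast-added to the pre-softmax attention scores.
mask_additive = mx.random.normal((B, N_q, T, T), dtype=mx.float16)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask_additive)

# Boolean mask: True marks query/key pairs that are allowed to attend.
mask_bool = mx.random.uniform(0, 1, (B, N_q, T, T), dtype=mx.float16) < 0.5
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask_bool)

The diff below rewrites TestSDPA.test_sdpa_grad to drive both the fused kernel and the slow reference through mx.vjp and mx.grad with exactly these mask types, in addition to the existing None and "causal" cases: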
@@ -739,37 +739,69 @@ class TestSDPA(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(out, expected, atol=1e-5))
 
     def test_sdpa_grad(self):
-        B, N_kv, T, D = (2, 8, 128, 64)
-        scale = D**-0.5
-
-        f1 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale)
-        f2 = lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
-
-        f3 = lambda q, k, v: mlx_ref_attn(q, k, v, scale=scale).sum()
-        f4 = lambda q, k, v: (
-            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
-        ).sum()
-
         # High tolerance due to cuDNN SDPA kernel requiring tf32.
         tolerance = {"rtol": 1e-2, "atol": 1e-2}
 
+        def test_vjp(slow, fast, primals):
+            cotan = mx.ones_like(primals[0])
+            o1, vjp1 = mx.vjp(slow, primals, [cotan])
+            o2, vjp2 = mx.vjp(fast, primals, [cotan])
+
+            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
+            for i in range(3):
+                self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
+
+        def test_grad(slow, fast, args):
+            g1 = mx.grad(slow)(*args)
+            g2 = mx.grad(fast)(*args)
+
+            self.assertTrue(mx.allclose(g1, g2, **tolerance))
+
+        sdpa_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
+            q, k, v, scale=scale, mask=mask
+        )
+        sdpa_mask_fast = lambda q, k, v, mask: mx.fast.scaled_dot_product_attention(
+            q, k, v, scale=scale, mask=mask
+        )
+
+        loss_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
+            q, k, v, scale=scale, mask=mask
+        ).sum()
+        loss_mask_fast = lambda q, k, v, mask: (
+            mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
+        ).sum()
+
+        B, N_kv, T, D = (2, 8, 128, 64)
+        scale = D**-0.5
+
         for N_q in (8, 32):
             q = mx.random.normal(shape=(B, N_q, T, D), dtype=mx.float16)
             k = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
             v = mx.random.normal(shape=(B, N_kv, T, D), dtype=mx.float16)
 
-            cotan = mx.ones_like(q)
-            o1, vjp1 = mx.vjp(f1, [q, k, v], [cotan])
-            o2, vjp2 = mx.vjp(f2, [q, k, v], [cotan])
-
-            self.assertTrue(mx.allclose(o1[0], o2[0], **tolerance))
-            for i in range(3):
-                self.assertTrue(mx.allclose(vjp1[i], vjp2[i], **tolerance))
-
-            g1 = mx.grad(f3)(q, k, v)
-            g2 = mx.grad(f4)(q, k, v)
-
-            self.assertTrue(mx.allclose(g1, g2, **tolerance))
+            mask_additive = mx.random.normal((B, N_q, T, T), dtype=mx.float16)
+            mask_bool = mx.random.uniform(0, 1, (B, N_q, T, T), dtype=mx.float16) < 0.5
+
+            for mask in (mask_additive, mask_bool):
+                test_vjp(sdpa_mask_slow, sdpa_mask_fast, [q, k, v, mask])
+                test_grad(loss_mask_slow, loss_mask_fast, [q, k, v, mask])
+
+            for mask in (None, "causal"):
+                sdpa_slow = lambda q, k, v: mlx_ref_attn(
+                    q, k, v, scale=scale, mask=mask
+                )
+                sdpa_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                    q, k, v, scale=scale, mask=mask
+                )
+                test_vjp(sdpa_slow, sdpa_fast, [q, k, v])
+
+                loss_slow = lambda q, k, v: mlx_ref_attn(
+                    q, k, v, scale=scale, mask=mask
+                ).sum()
+                loss_fast = lambda q, k, v: mx.fast.scaled_dot_product_attention(
+                    q, k, v, scale=scale, mask=mask
+                ).sum()
+                test_grad(loss_slow, loss_fast, [q, k, v])
 
 
 if __name__ == "__main__":
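The test compares against mlx_ref_attn, a pure-MLX reference defined elsewhere in the same test file. As a rough guide to the semantics being exercised, here is a minimal reference sketch (a hypothetical stand-in, not the repo's mlx_ref_attn), assuming additive float masks are added to the pre-softmax scores and boolean masks mark the positions that may attend:

import mlx.core as mx

def ref_attn(q, k, v, scale, mask=None):
    # Hypothetical stand-in for the test file's mlx_ref_attn.
    # Repeat K/V heads when there are more query heads (the GQA case,
    # exercised by N_q = 32 vs. N_kv = 8 in the test).
    n_rep = q.shape[1] // k.shape[1]
    if n_rep > 1:
        k = mx.repeat(k, n_rep, axis=1)
        v = mx.repeat(v, n_rep, axis=1)
    # Scaled attention scores, shape (B, N_q, T, T).
    scores = (q * scale) @ k.transpose(0, 1, 3, 2)
    if mask is not None:
        if mask.dtype == mx.bool_:
            # Boolean mask: keep scores where True, suppress the rest.
            # -1e4 stays finite in float16, unlike -1e9.
            scores = mx.where(mask, scores, mx.array(-1e4, dtype=scores.dtype))
        else:
            # Float mask: additive bias on the scores.
            scores = scores + mask
    return mx.softmax(scores, axis=-1) @ v

q = mx.random.normal(shape=(2, 32, 128, 64), dtype=mx.float16)
k = mx.random.normal(shape=(2, 8, 128, 64), dtype=mx.float16)
v = mx.random.normal(shape=(2, 8, 128, 64), dtype=mx.float16)
mask = mx.random.uniform(0, 1, (2, 32, 128, 128), dtype=mx.float16) < 0.5
out = ref_attn(q, k, v, scale=64**-0.5, mask=mask)

Under that reading, both mask flavors should match the fused CUDA path up to the tf32-driven tolerance the test sets.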