[CUDA] cuDNN forward attention (#2743)

* Separate sdpa kernels in another file

* Initial support for cuDNN SDPA

* Disable a few corner cases

* Remove scaled_dot_product_attention.h

* Use cuDNN attention for prefilling

* cuDNN SDPA requires Ampere and later

* Address reviews

* Do contiguous copy of inputs

Authored by Cheng on 2025-11-14 09:23:56 +09:00; committed by GitHub
parent b65f882df3
commit 3b2ffcefc3
7 changed files with 358 additions and 53 deletions
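
The bullets above note that cuDNN attention is used for prefilling and that cuDNN SDPA requires Ampere or later. As a rough sketch only (the helper name, the tuple-based compute capability, and the dtype check below are assumptions for illustration, not MLX's actual backend code), the kernel-selection logic those bullets describe might look like this:

    # Hypothetical sketch of the dispatch described in the commit bullets;
    # `use_cudnn_sdpa` and its arguments are illustrative, not MLX API.
    def use_cudnn_sdpa(compute_capability, query_seq_len, dtype):
        # "cuDNN SDPA requires Ampere and later": gate on SM 8.0+.
        if compute_capability < (8, 0):
            return False
        # "Use cuDNN attention for prefilling": multi-token (prefill) queries
        # go to cuDNN; single-token decode keeps the existing kernel.
        if query_seq_len == 1:
            return False
        # Assume (for illustration) the fused path targets half-precision inputs.
        return dtype in ("float16", "bfloat16")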

@@ -168,7 +168,7 @@ class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
         Dk = 64
-        if self.is_apple_silicon or mx.cuda.is_available():
+        if mx.is_available(mx.gpu):
             dtypes.append(np.half)
         for SEQUENCE_LENGTH in [63, 129, 400]:
@@ -240,7 +240,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         B = 1
         H = 32
         dtypes = [np.float32]
-        if self.is_apple_silicon or mx.cuda.is_available():
+        if mx.is_available(mx.gpu):
             dtypes.append(np.half)
         for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
@@ -549,12 +549,8 @@ class TestSDPA(mlx_tests.MLXTestCase):
 class TestSDPA(mlx_tests.MLXTestCase):
-    @property
-    def dtypes(self):
-        return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
     def test_sdpa(self):
-        if not mx.metal.is_available():
+        if not mx.is_available(mx.gpu):
             return
         # fmt: off
@@ -578,10 +574,11 @@ class TestSDPA(mlx_tests.MLXTestCase):
         # fmt: on
         shapes = shapes_64 + shapes_128
+        dtypes = ["float32", "float16"]
         masks = [None, "additive", "bool", "causal"]
         transposes = (False, True)
-        for dtype in self.dtypes:
+        for dtype in dtypes:
             for t in transposes:
                 for mask_str in masks:
                     for B, qL, kL, D, qH, kH in shapes:
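
For context, the op these tests exercise can be driven from Python as below. The shapes, dtype, and causal mask are illustrative, and whether the cuDNN path is actually taken depends on the build and hardware (CUDA build, Ampere or newer GPU):

    import mlx.core as mx

    # Prefill-shaped inputs: batch=1, heads=8, sequence=129, head_dim=64.
    q = mx.random.normal((1, 8, 129, 64)).astype(mx.float16)
    k = mx.random.normal((1, 8, 129, 64)).astype(mx.float16)
    v = mx.random.normal((1, 8, 129, 64)).astype(mx.float16)

    # Same public API as before; on a supported GPU the forward pass can now
    # be served by cuDNN's fused SDPA.
    out = mx.fast.scaled_dot_product_attention(
        q, k, v, scale=64**-0.5, mask="causal"
    )
    mx.eval(out)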