Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-15 01:19:21 +08:00)
[CUDA] cuDNN forward attention (#2743)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* Separate sdpa kernels in another file
* Initial support for cuDNN SDPA
* Disable a few corner cases
* Remove scaled_dot_product_attention.h
* Use cuDNN attention for prefilling
* cuDNN SDPA requires Ampere and later
* Address reviews
* Do contiguous copy of inputs
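For orientation, here is a minimal sketch of the fused attention call this commit accelerates; the shapes and scale below are illustrative assumptions, not values from the commit. `mx.fast.scaled_dot_product_attention` is MLX's fused-SDPA entry point, and per the message above, prefill-sized calls on CUDA builds are the ones routed to cuDNN (Ampere or newer), with the existing kernels covering the disabled corner cases.

```python
import math

import mlx.core as mx

# Illustrative prefill shapes (assumptions): batch 1, 8 heads, 128 tokens, head dim 64.
B, H, L, D = 1, 8, 128, 64
q = mx.random.normal((B, H, L, D), dtype=mx.float16)
k = mx.random.normal((B, H, L, D), dtype=mx.float16)
v = mx.random.normal((B, H, L, D), dtype=mx.float16)

# Fused SDPA; on a CUDA build with an Ampere-or-newer GPU, a prefill call
# like this is the case the new cuDNN attention path targets.
out = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=1.0 / math.sqrt(D), mask="causal"
)
mx.eval(out)
```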
```diff
@@ -168,7 +168,7 @@ class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
 
         Dk = 64
 
-        if self.is_apple_silicon or mx.cuda.is_available():
+        if mx.is_available(mx.gpu):
             dtypes.append(np.half)
 
         for SEQUENCE_LENGTH in [63, 129, 400]:
```
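The recurring edit across these hunks swaps backend-specific probes (`self.is_apple_silicon`, `mx.cuda.is_available()`, `mx.metal.is_available()`) for the backend-agnostic `mx.is_available(mx.gpu)` that the diff introduces. A minimal sketch of the resulting gating pattern:

```python
import mlx.core as mx
import numpy as np

# True on any GPU backend (Metal or CUDA), so the half-precision cases run
# wherever a GPU exists rather than being tied to one backend's probe.
dtypes = [np.float32]
if mx.is_available(mx.gpu):
    dtypes.append(np.half)
```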
```diff
@@ -240,7 +240,7 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         B = 1
         H = 32
         dtypes = [np.float32]
-        if self.is_apple_silicon or mx.cuda.is_available():
+        if mx.is_available(mx.gpu):
             dtypes.append(np.half)
 
         for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
```
```diff
@@ -549,12 +549,8 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
 
 
 class TestSDPA(mlx_tests.MLXTestCase):
-    @property
-    def dtypes(self):
-        return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
-
     def test_sdpa(self):
-        if not mx.metal.is_available():
+        if not mx.is_available(mx.gpu):
             return
 
         # fmt: off
```
```diff
@@ -578,10 +574,11 @@ class TestSDPA(mlx_tests.MLXTestCase):
         # fmt: on
 
         shapes = shapes_64 + shapes_128
+        dtypes = ["float32", "float16"]
         masks = [None, "additive", "bool", "causal"]
         transposes = (False, True)
 
-        for dtype in self.dtypes:
+        for dtype in dtypes:
             for t in transposes:
                 for mask_str in masks:
                     for B, qL, kL, D, qH, kH in shapes:
```
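A hedged sketch of the pattern these loops exercise: the fused kernel's output is compared against an unfused reference. `sdpa_ref` below is a hypothetical helper written for illustration, not the repository's actual reference implementation.

```python
import math

import mlx.core as mx

def sdpa_ref(q, k, v, scale, mask=None):
    # Unfused reference: softmax(q k^T * scale + mask) v, computed in float32.
    scores = (q * scale) @ k.transpose(0, 1, 3, 2)
    if mask is not None:
        scores = scores + mask
    return mx.softmax(scores, axis=-1) @ v

B, H, L, D = 1, 8, 129, 64  # illustrative shapes drawn from the tested ranges
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
scale = 1.0 / math.sqrt(D)

fused = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
ref = sdpa_ref(q, k, v, scale)
assert mx.allclose(fused, ref, atol=1e-4).item()
```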