This commit is contained in:
Awni Hannun 2024-03-04 23:02:27 -08:00 committed by GitHub
parent 0787724c44
commit 859ae15a54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -32,12 +32,8 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
return mlx_primitives_sdpa(q, k, v, scale)
class TestFastInferenceSDPA(mlx_tests.MLXTestCase):
@property
def dtypes(self):
return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
def test_fast_inference_sdpa(self):
class TestFastSDPA(mlx_tests.MLXTestCase):
def test_fast_sdpa(self):
# Not yet supported:
# * K pre-transposed in kernel, V pre-transposed in kernel
@ -65,9 +61,13 @@ class TestFastInferenceSDPA(mlx_tests.MLXTestCase):
B = 1
H = 32
dtypes = [np.float32]
if not self.is_linux:
dtypes.append(np.half)
for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
for DO_GQA in [0, 1]:
for DTYPE in [np.float32, np.half]:
for DTYPE in dtypes:
n_kv_heads = 8 if DO_GQA else 32
q_npy = np.random.normal(0.0, 1.0, (B, H, R, Dk)).astype(DTYPE)
k_npy = np.random.normal(