This commit is contained in:
Alex Barron 2024-12-06 10:26:54 -08:00
parent 769704653a
commit 82a956c1d9

View File

@ -128,10 +128,13 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
B = 1
H = 32
dtypes = [mx.float32]
if self.is_apple_silicon:
dtypes.append(mx.float16)
tests = product(
[1, 7, 9, 32, 63, 67, 129, 2000], # sequence length
[False, True], # gqa
[mx.float32, mx.float16],
dtypes,
[4, 8], # bits
)
for sequence_length, do_gqa, dtype, bits in tests: