mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
fix test
This commit is contained in:
parent
769704653a
commit
82a956c1d9
@ -128,10 +128,13 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
B = 1
|
B = 1
|
||||||
H = 32
|
H = 32
|
||||||
|
|
||||||
|
dtypes = [mx.float32]
|
||||||
|
if self.is_apple_silcon:
|
||||||
|
dtypes.append(mx.float16)
|
||||||
tests = product(
|
tests = product(
|
||||||
[1, 7, 9, 32, 63, 67, 129, 2000], # sequence length
|
[1, 7, 9, 32, 63, 67, 129, 2000], # sequence length
|
||||||
[False, True], # gqa
|
[False, True], # gqa
|
||||||
[mx.float32, mx.float16],
|
dtypes,
|
||||||
[4, 8], # bits
|
[4, 8], # bits
|
||||||
)
|
)
|
||||||
for sequence_length, do_gqa, dtype, bits in tests:
|
for sequence_length, do_gqa, dtype, bits in tests:
|
||||||
|
Loading…
Reference in New Issue
Block a user