mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
fix test
This commit is contained in:
parent
769704653a
commit
82a956c1d9
@ -128,10 +128,13 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
B = 1
|
||||
H = 32
|
||||
|
||||
dtypes = [mx.float32]
|
||||
if self.is_apple_silcon:
|
||||
dtypes.append(mx.float16)
|
||||
tests = product(
|
||||
[1, 7, 9, 32, 63, 67, 129, 2000], # sequence length
|
||||
[False, True], # gqa
|
||||
[mx.float32, mx.float16],
|
||||
dtypes,
|
||||
[4, 8], # bits
|
||||
)
|
||||
for sequence_length, do_gqa, dtype, bits in tests:
|
||||
|
Loading…
Reference in New Issue
Block a user