mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Fix test (#785)
This commit is contained in:
parent
0787724c44
commit
859ae15a54
@ -32,12 +32,8 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
|
||||
return mlx_primitives_sdpa(q, k, v, scale)
|
||||
|
||||
|
||||
class TestFastInferenceSDPA(mlx_tests.MLXTestCase):
|
||||
@property
|
||||
def dtypes(self):
|
||||
return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
|
||||
|
||||
def test_fast_inference_sdpa(self):
|
||||
class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
def test_fast_sdpa(self):
|
||||
|
||||
# Not yet supported:
|
||||
# * K pre-transposed in kernel, V pre-transposed in kernel
|
||||
@ -65,9 +61,13 @@ class TestFastInferenceSDPA(mlx_tests.MLXTestCase):
|
||||
|
||||
B = 1
|
||||
H = 32
|
||||
dtypes = [np.float32]
|
||||
if not self.is_linux:
|
||||
dtypes.append(np.half)
|
||||
|
||||
for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
|
||||
for DO_GQA in [0, 1]:
|
||||
for DTYPE in [np.float32, np.half]:
|
||||
for DTYPE in dtypes:
|
||||
n_kv_heads = 8 if DO_GQA else 32
|
||||
q_npy = np.random.normal(0.0, 1.0, (B, H, R, Dk)).astype(DTYPE)
|
||||
k_npy = np.random.normal(
|
||||
|
Loading…
Reference in New Issue
Block a user