mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
Fix test (#785)
This commit is contained in:
parent
0787724c44
commit
859ae15a54
@ -32,12 +32,8 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale):
|
|||||||
return mlx_primitives_sdpa(q, k, v, scale)
|
return mlx_primitives_sdpa(q, k, v, scale)
|
||||||
|
|
||||||
|
|
||||||
class TestFastInferenceSDPA(mlx_tests.MLXTestCase):
|
class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||||
@property
|
def test_fast_sdpa(self):
|
||||||
def dtypes(self):
|
|
||||||
return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
|
|
||||||
|
|
||||||
def test_fast_inference_sdpa(self):
|
|
||||||
|
|
||||||
# Not yet supported:
|
# Not yet supported:
|
||||||
# * K pre-transposed in kernel, V pre-transposed in kernel
|
# * K pre-transposed in kernel, V pre-transposed in kernel
|
||||||
@ -65,9 +61,13 @@ class TestFastInferenceSDPA(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
B = 1
|
B = 1
|
||||||
H = 32
|
H = 32
|
||||||
|
dtypes = [np.float32]
|
||||||
|
if not self.is_linux:
|
||||||
|
dtypes.append(np.half)
|
||||||
|
|
||||||
for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
|
for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
|
||||||
for DO_GQA in [0, 1]:
|
for DO_GQA in [0, 1]:
|
||||||
for DTYPE in [np.float32, np.half]:
|
for DTYPE in dtypes:
|
||||||
n_kv_heads = 8 if DO_GQA else 32
|
n_kv_heads = 8 if DO_GQA else 32
|
||||||
q_npy = np.random.normal(0.0, 1.0, (B, H, R, Dk)).astype(DTYPE)
|
q_npy = np.random.normal(0.0, 1.0, (B, H, R, Dk)).astype(DTYPE)
|
||||||
k_npy = np.random.normal(
|
k_npy = np.random.normal(
|
||||||
|
Loading…
Reference in New Issue
Block a user