mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-25 03:31:17 +08:00
add test
This commit is contained in:
parent
12a4d89a7c
commit
3507c104a5
@ -1,5 +1,6 @@
|
|||||||
import math
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
|
from itertools import product
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
@ -113,61 +114,63 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
R = 1
|
R = 1
|
||||||
Dk = 128
|
Dk = 128
|
||||||
scale = float(1.0 / np.sqrt(128.0))
|
scale = float(1.0 / np.sqrt(128.0))
|
||||||
q_npy = np.random.normal(0.0, 1.0, (1, 32, R, Dk)).astype(np.float32)
|
q = mx.random.normal(shape=(1, 32, R, Dk))
|
||||||
k_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32)
|
k = mx.random.normal(shape=(1, 32, L, Dk))
|
||||||
v_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32)
|
v = mx.random.normal(shape=(1, 32, L, Dk))
|
||||||
|
|
||||||
q_mlx = mx.array(q_npy)
|
reference = mlx_primitives_sdpa(q, k, v, scale)
|
||||||
k_mlx = mx.array(k_npy)
|
|
||||||
v_mlx = mx.array(v_npy)
|
|
||||||
|
|
||||||
reference = mlx_primitives_sdpa(q_mlx, k_mlx, v_mlx, scale)
|
o = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
|
||||||
|
|
||||||
o_mlx = mx.fast.scaled_dot_product_attention(
|
self.assertListEqual(list(reference.shape), list(o.shape))
|
||||||
q_mlx, k_mlx, v_mlx, scale=scale, mask=None
|
self.assertTrue(mx.allclose(o, reference, atol=1e-4))
|
||||||
)
|
|
||||||
|
|
||||||
self.assertListEqual(list(reference.shape), list(o_mlx.shape))
|
|
||||||
self.assertTrue(mx.allclose(o_mlx, reference, atol=1e-4))
|
|
||||||
|
|
||||||
B = 1
|
B = 1
|
||||||
H = 32
|
H = 32
|
||||||
dtypes = [np.float32]
|
|
||||||
if self.is_apple_silicon:
|
|
||||||
dtypes.append(np.half)
|
|
||||||
|
|
||||||
for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
|
tests = product(
|
||||||
for DO_GQA in [0, 1]:
|
[1, 7, 9, 32, 63, 67, 129, 2000], # sequence length
|
||||||
for DTYPE in dtypes:
|
[False, True], # gqa
|
||||||
n_kv_heads = 8 if DO_GQA else 32
|
[mx.float32, mx.float16],
|
||||||
q_npy = np.random.normal(0.0, 1.0, (B, H, R, Dk)).astype(DTYPE)
|
[4, 8], # bits
|
||||||
k_npy = np.random.normal(
|
)
|
||||||
0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
|
for sequence_length, do_gqa, dtype, bits in tests:
|
||||||
).astype(DTYPE)
|
with self.subTest(
|
||||||
v_npy = np.random.normal(
|
sequence_length=sequence_length, gqa=do_gqa, dtype=dtype, bits=bits
|
||||||
0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
|
):
|
||||||
).astype(DTYPE)
|
n_kv_heads = 8 if do_gqa else 32
|
||||||
|
q = mx.random.normal(shape=(B, H, R, Dk), dtype=dtype)
|
||||||
|
k = mx.random.normal(
|
||||||
|
shape=(B, n_kv_heads, sequence_length, Dk), dtype=dtype
|
||||||
|
)
|
||||||
|
v = mx.random.normal(
|
||||||
|
shape=(B, n_kv_heads, sequence_length, Dk), dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
q_mlx = mx.array(q_npy)
|
k_q = mx.quantize(k, bits=bits)
|
||||||
k_mlx = mx.array(k_npy)
|
v_q = mx.quantize(v, bits=bits)
|
||||||
v_mlx = mx.array(v_npy)
|
k_d = mx.dequantize(*k_q, bits=bits)
|
||||||
|
v_d = mx.dequantize(*v_q, bits=bits)
|
||||||
|
|
||||||
reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
|
reference = mlx_primitives_sdpa_with_gqa(q, k_d, v_d, scale)
|
||||||
o_mlx = mx.fast.scaled_dot_product_attention(
|
o = mx.fast.scaled_dot_product_attention(q, k_d, v_d, scale=scale)
|
||||||
q_mlx, k_mlx, v_mlx, scale=scale
|
o_q = mx.fast.quantized_scaled_dot_product_attention(
|
||||||
)
|
q, *k_q, *v_q, scale=scale, bits=bits
|
||||||
|
)
|
||||||
|
|
||||||
self.assertListEqual(list(reference.shape), list(o_mlx.shape))
|
self.assertListEqual(list(reference.shape), list(o.shape))
|
||||||
rtol = 1e-5
|
rtol = 1e-5
|
||||||
atol = 1e-1
|
atol = 1e-1
|
||||||
|
|
||||||
if SEQUENCE_LENGTH > 500:
|
if sequence_length > 500:
|
||||||
rtol = 1e-2
|
rtol = 1e-2
|
||||||
|
|
||||||
if DTYPE == np.half:
|
if dtype == mx.float16:
|
||||||
rtol = 1e-2
|
rtol = 1e-2
|
||||||
|
|
||||||
self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
|
# np.testing.assert_allclose(o_q, reference, rtol=rtol, atol=atol)
|
||||||
|
self.assertTrue(mx.allclose(o_q, reference, rtol=rtol, atol=atol))
|
||||||
|
self.assertTrue(mx.allclose(o, reference, rtol=rtol, atol=atol))
|
||||||
|
|
||||||
q = mx.random.normal(shape=(1, 32, 1, Dk))
|
q = mx.random.normal(shape=(1, 32, 1, Dk))
|
||||||
k = mx.random.normal(shape=(1, 32, 32, Dk))
|
k = mx.random.normal(shape=(1, 32, 32, Dk))
|
||||||
|
Loading…
Reference in New Issue
Block a user