mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix failing tests for CPU
This commit is contained in:
@@ -163,7 +163,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
def test_qmm(self):
|
||||
key = mx.random.key(0)
|
||||
k1, k2 = mx.random.split(key)
|
||||
dtype = mx.float16
|
||||
dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
|
||||
tests = product(
|
||||
[128, 64, 32], # group_size
|
||||
[2, 4, 8], # bits
|
||||
@@ -193,6 +193,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
)
|
||||
y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
|
||||
tol = 1e-3 if dtype == mx.float32 else 1.5e-3
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
def test_qmm_vjp(self):
|
||||
@@ -842,16 +844,16 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
|
||||
key = mx.random.key(0)
|
||||
k1, k2, k3 = mx.random.split(key, 3)
|
||||
dtype = mx.float16
|
||||
dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
|
||||
|
||||
for L, K, D, E, I, transpose, mode in parameters:
|
||||
with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
|
||||
if mode == "mxfp4":
|
||||
group_size = 32
|
||||
dtype = mx.bfloat16
|
||||
dtype = mx.bfloat16 if mx.is_available(mx.gpu) else mx.float32
|
||||
else:
|
||||
group_size = 64
|
||||
dtype = mx.float16
|
||||
dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
|
||||
|
||||
K, D = (K, D) if transpose else (D, K)
|
||||
ishape = (L, I)
|
||||
|
||||
Reference in New Issue
Block a user