Fix failing tests for CPU

Jagrit Digani
2025-11-19 12:37:58 -08:00
parent 166dfac5cf
commit ff1afd8b3d
2 changed files with 7 additions and 4 deletions

@@ -163,7 +163,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
-        dtype = mx.float16
+        dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
         tests = product(
             [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
@@ -193,6 +193,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 )
                 y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)
                 self.assertEqual(y_q.shape, y_hat.shape)
-                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                tol = 1e-3 if dtype == mx.float32 else 1.5e-3
+                self.assertLess((y_q - y_hat).abs().max(), tol)
 
     def test_qmm_vjp(self):
@@ -842,16 +844,16 @@ class TestQuantized(mlx_tests.MLXTestCase):
         key = mx.random.key(0)
         k1, k2, k3 = mx.random.split(key, 3)
-        dtype = mx.float16
+        dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
         for L, K, D, E, I, transpose, mode in parameters:
             with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
                 if mode == "mxfp4":
                     group_size = 32
-                    dtype = mx.bfloat16
+                    dtype = mx.bfloat16 if mx.is_available(mx.gpu) else mx.float32
                 else:
                     group_size = 64
-                    dtype = mx.float16
+                    dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
                 K, D = (K, D) if transpose else (D, K)
                 ishape = (L, I)
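
For context, the pattern applied throughout this diff can be exercised on its own. The sketch below is not part of the commit; the shapes, the 1e-2 input scaling, and the final print are illustrative. It picks half precision only when the GPU backend is available, falls back to float32 on CPU-only machines, and widens the comparison tolerance for the half-precision case, mirroring the dtype and tol lines added above.

import mlx.core as mx

# Half precision only when a GPU backend is present; float32 otherwise,
# as in the updated tests.
dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32

group_size, bits = 64, 4
x = (1e-2 * mx.random.normal(shape=(32, 512))).astype(dtype)
w = mx.random.normal(shape=(512, 512)).astype(dtype)

# Quantize the weights, then compare the fused quantized matmul against a
# dequantize-then-matmul reference, as the tests do.
w_q, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size=group_size, bits=bits)

y_q = mx.quantized_matmul(
    x, w_q, scales, biases, transpose=True, group_size=group_size, bits=bits
)
y_hat = x @ w_hat.T

# Dtype-dependent tolerance, matching the new tol line in test_qmm.
tol = 1e-3 if dtype == mx.float32 else 1.5e-3
err = (y_q - y_hat).abs().max().item()
print(f"dtype={dtype}, max deviation={err:.2e}, tolerance={tol:.1e}")

The float32 fallback presumably reflects that the half-precision quantized kernels are exercised on the GPU, while the CPU path is validated in full precision against the tighter 1e-3 bound; half-precision runs get the slightly looser 1.5e-3 bound.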