Update dtypes for quanitzed tests based on if gpu is being used

This commit is contained in:
Jagrit Digani
2025-11-19 13:21:27 -08:00
parent 75f4788b29
commit a72406b928
4 changed files with 16 additions and 11 deletions

View File

@@ -163,7 +163,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
def test_qmm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
tests = product(
[128, 64, 32], # group_size
[2, 4, 8], # bits
@@ -195,7 +195,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
tol = 1e-3 if dtype == mx.float32 else 1.5e-3
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
self.assertLess((y_q - y_hat).abs().max(), tol)
def test_qmm_vjp(self):
key = mx.random.key(0)
@@ -844,16 +844,20 @@ class TestQuantized(mlx_tests.MLXTestCase):
key = mx.random.key(0)
k1, k2, k3 = mx.random.split(key, 3)
dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
for L, K, D, E, I, transpose, mode in parameters:
with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
if mode == "mxfp4":
group_size = 32
dtype = mx.bfloat16 if mx.is_available(mx.gpu) else mx.float32
dtype = (
mx.bfloat16 if (mx.default_device() == mx.gpu) else mx.float32
)
else:
group_size = 64
dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
dtype = (
mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
)
K, D = (K, D) if transpose else (D, K)
ishape = (L, I)