Fix failing tests for CPU

This commit is contained in:
Jagrit Digani
2025-11-19 12:37:58 -08:00
parent 166dfac5cf
commit ff1afd8b3d
2 changed files with 7 additions and 4 deletions

View File

@@ -138,6 +138,7 @@ if(MLX_ENABLE_NAX)
build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})
build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS}) build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS})
build_kernel(fp_quantized_nax fp_quantized_nax.h ${STEEL_NAX_HEADERS})
set(STEEL_NAX_ATTN_HEADERS set(STEEL_NAX_ATTN_HEADERS
steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h

View File

@@ -163,7 +163,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
def test_qmm(self): def test_qmm(self):
key = mx.random.key(0) key = mx.random.key(0)
k1, k2 = mx.random.split(key) k1, k2 = mx.random.split(key)
dtype = mx.float16 dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
tests = product( tests = product(
[128, 64, 32], # group_size [128, 64, 32], # group_size
[2, 4, 8], # bits [2, 4, 8], # bits
@@ -193,6 +193,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
) )
y_hat = (x @ w_hat.T) if transposed else (x @ w_hat) y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)
self.assertEqual(y_q.shape, y_hat.shape) self.assertEqual(y_q.shape, y_hat.shape)
tol = 1e-3 if dtype == mx.float32 else 1.5e-3
self.assertLess((y_q - y_hat).abs().max(), 1e-3) self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_qmm_vjp(self): def test_qmm_vjp(self):
@@ -842,16 +844,16 @@ class TestQuantized(mlx_tests.MLXTestCase):
key = mx.random.key(0) key = mx.random.key(0)
k1, k2, k3 = mx.random.split(key, 3) k1, k2, k3 = mx.random.split(key, 3)
dtype = mx.float16 dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
for L, K, D, E, I, transpose, mode in parameters: for L, K, D, E, I, transpose, mode in parameters:
with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode): with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
if mode == "mxfp4": if mode == "mxfp4":
group_size = 32 group_size = 32
dtype = mx.bfloat16 dtype = mx.bfloat16 if mx.is_available(mx.gpu) else mx.float32
else: else:
group_size = 64 group_size = 64
dtype = mx.float16 dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
K, D = (K, D) if transpose else (D, K) K, D = (K, D) if transpose else (D, K)
ishape = (L, I) ishape = (L, I)