diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 514f6038c..5215fb346 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -138,6 +138,7 @@ if(MLX_ENABLE_NAX)
   build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})
   build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS})
+  build_kernel(fp_quantized_nax fp_quantized_nax.h ${STEEL_NAX_HEADERS})
   set(STEEL_NAX_ATTN_HEADERS
       steel/defines.h
       steel/utils.h
       steel/attn/nax.h
       steel/utils/type_traits.h
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index ee6cef6ec..2ba4b64d5 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -163,7 +163,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
-        dtype = mx.float16
+        dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
         tests = product(
             [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
@@ -193,6 +193,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 )
                 y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)
                 self.assertEqual(y_q.shape, y_hat.shape)
-                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+                tol = 1e-3 if dtype == mx.float32 else 1.5e-3
+                self.assertLess((y_q - y_hat).abs().max(), tol)
 
     def test_qmm_vjp(self):
@@ -842,16 +844,16 @@ class TestQuantized(mlx_tests.MLXTestCase):
         key = mx.random.key(0)
         k1, k2, k3 = mx.random.split(key, 3)
 
-        dtype = mx.float16
+        dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
 
         for L, K, D, E, I, transpose, mode in parameters:
             with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
                 if mode == "mxfp4":
                     group_size = 32
-                    dtype = mx.bfloat16
+                    dtype = mx.bfloat16 if mx.is_available(mx.gpu) else mx.float32
                 else:
                     group_size = 64
-                    dtype = mx.float16
+                    dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
 
                 K, D = (K, D) if transpose else (D, K)
                 ishape = (L, I)
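For reference, the sketch below illustrates the dtype/tolerance pattern these test changes introduce, outside the test harness: run in half precision only when the GPU backend is available, fall back to float32 on CPU, and widen the tolerance slightly on the half-precision path. It is a minimal illustration, not part of the patch; the shapes, the 0.1 scaling factor, and the `transpose=True` call are arbitrary choices for the example, and `mx.is_available(mx.gpu)` is used as it appears in the diff.

```python
import mlx.core as mx

# Device-dependent dtype selection, mirroring the pattern in the test diff:
# half precision when the GPU backend is available, float32 otherwise.
dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32

# Tolerance is slightly looser for the half-precision path.
tol = 1e-3 if dtype == mx.float32 else 1.5e-3

group_size, bits = 64, 4
x = (0.1 * mx.random.normal(shape=(8, 64))).astype(dtype)
w = (0.1 * mx.random.normal(shape=(64, 64))).astype(dtype)

# Quantize, then compare the fused quantized matmul against an explicit
# dequantize-and-matmul reference, as the test does.
w_q, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size=group_size, bits=bits)

y_q = mx.quantized_matmul(
    x, w_q, scales, biases, transpose=True, group_size=group_size, bits=bits
)
y_hat = x @ w_hat.T

assert float((y_q - y_hat).abs().max()) < tol
```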