mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix failing tests for CPU
This commit is contained in:
@@ -138,6 +138,7 @@ if(MLX_ENABLE_NAX)
|
|||||||
build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})
|
build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})
|
||||||
|
|
||||||
build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS})
|
build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS})
|
||||||
|
build_kernel(fp_quantized_nax fp_quantized_nax.h ${STEEL_NAX_HEADERS})
|
||||||
|
|
||||||
set(STEEL_NAX_ATTN_HEADERS
|
set(STEEL_NAX_ATTN_HEADERS
|
||||||
steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h
|
steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h
|
||||||
|
|||||||
@@ -163,7 +163,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
def test_qmm(self):
|
def test_qmm(self):
|
||||||
key = mx.random.key(0)
|
key = mx.random.key(0)
|
||||||
k1, k2 = mx.random.split(key)
|
k1, k2 = mx.random.split(key)
|
||||||
dtype = mx.float16
|
dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
|
||||||
tests = product(
|
tests = product(
|
||||||
[128, 64, 32], # group_size
|
[128, 64, 32], # group_size
|
||||||
[2, 4, 8], # bits
|
[2, 4, 8], # bits
|
||||||
@@ -193,6 +193,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
)
|
)
|
||||||
y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)
|
y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)
|
||||||
self.assertEqual(y_q.shape, y_hat.shape)
|
self.assertEqual(y_q.shape, y_hat.shape)
|
||||||
|
|
||||||
|
tol = 1e-3 if dtype == mx.float32 else 1.5e-3
|
||||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||||
|
|
||||||
def test_qmm_vjp(self):
|
def test_qmm_vjp(self):
|
||||||
@@ -842,16 +844,16 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
key = mx.random.key(0)
|
key = mx.random.key(0)
|
||||||
k1, k2, k3 = mx.random.split(key, 3)
|
k1, k2, k3 = mx.random.split(key, 3)
|
||||||
dtype = mx.float16
|
dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
|
||||||
|
|
||||||
for L, K, D, E, I, transpose, mode in parameters:
|
for L, K, D, E, I, transpose, mode in parameters:
|
||||||
with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
|
with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
|
||||||
if mode == "mxfp4":
|
if mode == "mxfp4":
|
||||||
group_size = 32
|
group_size = 32
|
||||||
dtype = mx.bfloat16
|
dtype = mx.bfloat16 if mx.is_available(mx.gpu) else mx.float32
|
||||||
else:
|
else:
|
||||||
group_size = 64
|
group_size = 64
|
||||||
dtype = mx.float16
|
dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
|
||||||
|
|
||||||
K, D = (K, D) if transpose else (D, K)
|
K, D = (K, D) if transpose else (D, K)
|
||||||
ishape = (L, I)
|
ishape = (L, I)
|
||||||
|
|||||||
Reference in New Issue
Block a user