diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index 3e64e82da..e4f625383 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -361,7 +361,7 @@ void steel_matmul_regular_axpby(
   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
-        (a.dtype() != float32 || env::enable_tf32())) {
+        (env::enable_tf32() || a.dtype() != float32)) {
       return steel_matmul_regular_axpby_nax(
           /* const Stream& s = */ s,
           /* metal::Device& d = */ d,
@@ -2201,7 +2201,8 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     if (__builtin_available(
             macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
       if (metal::is_nax_available() &&
-          (a.dtype() != float32 || env::enable_tf32())) {
+          !issubdtype(a.dtype(), complexfloating) &&
+          (env::enable_tf32() || a.dtype() != float32)) {
         return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
       }
     }
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index f2c4e7d2f..55b69b9ca 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -674,7 +674,7 @@ void qmm(
   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
-        (x.dtype() != float32 || env::enable_tf32())) {
+        (env::enable_tf32() || x.dtype() != float32)) {
       return qmm_nax(
           /* const array& x = */ x,
           /* const array& w = */ w,
@@ -776,7 +776,7 @@ void gather_qmm(
   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
-        (x.dtype() != float32 || env::enable_tf32())) {
+        (env::enable_tf32() || x.dtype() != float32)) {
       return gather_qmm_nax(
           /* const array& x = */ x,
           /* const array& w = */ w,
@@ -1130,7 +1130,7 @@ void gather_qmm_rhs(
   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && transpose &&
-        (x_.dtype() != float32 || env::enable_tf32())) {
+        (env::enable_tf32() || x_.dtype() != float32)) {
       return gather_qmm_rhs_nax(
           /* const array& x_ = */ x_,
           /* const array& w_ = */ w_,
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 156093757..d3920b55d 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -166,7 +166,7 @@ void sdpa_full_self_attention_metal(
 #ifdef MLX_ENABLE_NAX
   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && q.shape(3) != 80 &&
-        (q.dtype() != float32 || env::enable_tf32())) {
+        (env::enable_tf32() || q.dtype() != float32)) {
       return sdpa_full_self_attention_nax(
           /* const Stream& s = */ s,
           /* metal::Device& d = */ d,
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 2ba4b64d5..ce63544b4 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -163,7 +163,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
-        dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
+        dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
         tests = product(
             [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
@@ -195,7 +195,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 self.assertEqual(y_q.shape, y_hat.shape)
                 tol = 1e-3 if dtype == mx.float32 else 1.5e-3
-                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                self.assertLess((y_q - y_hat).abs().max(), tol)
 
     def test_qmm_vjp(self):
         key = mx.random.key(0)
@@ -844,16 +844,20 @@ class TestQuantized(mlx_tests.MLXTestCase):
         key = mx.random.key(0)
         k1, k2, k3 = mx.random.split(key, 3)
-        dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
+        dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
 
         for L, K, D, E, I, transpose, mode in parameters:
             with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
                 if mode == "mxfp4":
                     group_size = 32
-                    dtype = mx.bfloat16 if mx.is_available(mx.gpu) else mx.float32
+                    dtype = (
+                        mx.bfloat16 if (mx.default_device() == mx.gpu) else mx.float32
+                    )
                 else:
                     group_size = 64
-                    dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
+                    dtype = (
+                        mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
+                    )
 
                 K, D = (K, D) if transpose else (D, K)
                 ishape = (L, I)