Update dtypes for quantized tests based on whether the GPU is being used

Jagrit Digani
2025-11-19 13:21:27 -08:00
parent 75f4788b29
commit a72406b928
4 changed files with 16 additions and 11 deletions
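
For context, a minimal sketch of the dtype-selection pattern the updated tests converge on. Only mx.default_device(), mx.gpu, and the dtype constants are taken from the diff below; the variable names and the print are illustrative:

import mlx.core as mx

# Pick a reduced-precision dtype only when the GPU backend is actually in
# use (not merely available); fall back to float32 on the CPU.
dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32

# The mxfp4 test cases in the diff use bfloat16 on the GPU instead
# (mxfp4_dtype is an illustrative name, not one from the tests).
mxfp4_dtype = mx.bfloat16 if (mx.default_device() == mx.gpu) else mx.float32
print(dtype, mxfp4_dtype)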

View File

@@ -361,7 +361,7 @@ void steel_matmul_regular_axpby(
   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
-        (a.dtype() != float32 || env::enable_tf32())) {
+        (env::enable_tf32() || a.dtype() != float32)) {
       return steel_matmul_regular_axpby_nax<CHECK_AB>(
           /* const Stream& s = */ s,
           /* metal::Device& d = */ d,
@@ -2201,7 +2201,8 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (__builtin_available(
           macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() &&
-        (a.dtype() != float32 || env::enable_tf32())) {
+        !issubdtype(a.dtype(), complexfloating) &&
+        (env::enable_tf32() || a.dtype() != float32)) {
       return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
     }
   }

View File

@@ -674,7 +674,7 @@ void qmm(
   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
-        (x.dtype() != float32 || env::enable_tf32())) {
+        (env::enable_tf32() || x.dtype() != float32)) {
       return qmm_nax(
           /* const array& x = */ x,
           /* const array& w = */ w,
@@ -776,7 +776,7 @@ void gather_qmm(
   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
-        (x.dtype() != float32 || env::enable_tf32())) {
+        (env::enable_tf32() || x.dtype() != float32)) {
       return gather_qmm_nax(
           /* const array& x = */ x,
           /* const array& w = */ w,
@@ -1130,7 +1130,7 @@ void gather_qmm_rhs(
   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && transpose &&
-        (x_.dtype() != float32 || env::enable_tf32())) {
+        (env::enable_tf32() || x_.dtype() != float32)) {
       return gather_qmm_rhs_nax(
           /* const array& x_ = */ x_,
           /* const array& w_ = */ w_,

View File

@@ -166,7 +166,7 @@ void sdpa_full_self_attention_metal(
 #ifdef MLX_ENABLE_NAX
   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && q.shape(3) != 80 &&
-        (q.dtype() != float32 || env::enable_tf32())) {
+        (env::enable_tf32() || q.dtype() != float32)) {
       return sdpa_full_self_attention_nax(
           /* const Stream& s = */ s,
           /* metal::Device& d = */ d,

View File

@@ -163,7 +163,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
-        dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
+        dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
         tests = product(
             [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
@@ -195,7 +195,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 self.assertEqual(y_q.shape, y_hat.shape)
                 tol = 1e-3 if dtype == mx.float32 else 1.5e-3
-                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                self.assertLess((y_q - y_hat).abs().max(), tol)

     def test_qmm_vjp(self):
         key = mx.random.key(0)
@@ -844,16 +844,20 @@ class TestQuantized(mlx_tests.MLXTestCase):
         key = mx.random.key(0)
         k1, k2, k3 = mx.random.split(key, 3)
-        dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
+        dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
         for L, K, D, E, I, transpose, mode in parameters:
             with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
                 if mode == "mxfp4":
                     group_size = 32
-                    dtype = mx.bfloat16 if mx.is_available(mx.gpu) else mx.float32
+                    dtype = (
+                        mx.bfloat16 if (mx.default_device() == mx.gpu) else mx.float32
+                    )
                 else:
                     group_size = 64
-                    dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
+                    dtype = (
+                        mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
+                    )
                 K, D = (K, D) if transpose else (D, K)
                 ishape = (L, I)
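
A minimal, self-contained sketch of the dtype-dependent tolerance check that the second test hunk switches to. The arrays here are placeholders, not the test's real quantized-matmul outputs; only the dtype/tolerance selection mirrors the diff:

import mlx.core as mx

dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32

# Placeholder arrays standing in for the quantized and reference results.
y_q = mx.ones((8, 8), dtype=dtype)
y_hat = mx.ones((8, 8), dtype=dtype)

# Looser tolerance for reduced-precision dtypes, tighter for float32.
tol = 1e-3 if dtype == mx.float32 else 1.5e-3
assert (y_q - y_hat).abs().max().item() < tol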