mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Update dtypes for quantized tests based on whether the GPU is being used
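
For context on the test hunks below: the dtype is now chosen by whether the GPU is actually the default device, not merely by whether a GPU is available. A minimal sketch of that selection pattern (illustrative only, not part of this diff; it assumes the usual mlx.core import):

import mlx.core as mx

# Use half precision only when the GPU is the default device; on the CPU
# the tests keep float32 so the tighter tolerances still hold.
dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
x = mx.random.normal((4, 64)).astype(dtype)
print(x.dtype)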

@@ -361,7 +361,7 @@ void steel_matmul_regular_axpby(

   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
-        (a.dtype() != float32 || env::enable_tf32())) {
+        (env::enable_tf32() || a.dtype() != float32)) {
       return steel_matmul_regular_axpby_nax<CHECK_AB>(
           /* const Stream& s = */ s,
           /* metal::Device& d = */ d,

@@ -2201,7 +2201,8 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (__builtin_available(
           macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() &&
-        (a.dtype() != float32 || env::enable_tf32())) {
+        !issubdtype(a.dtype(), complexfloating) &&
+        (env::enable_tf32() || a.dtype() != float32)) {
       return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
     }
   }

@@ -674,7 +674,7 @@ void qmm(

   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
-        (x.dtype() != float32 || env::enable_tf32())) {
+        (env::enable_tf32() || x.dtype() != float32)) {
       return qmm_nax(
           /* const array& x = */ x,
           /* const array& w = */ w,

@@ -776,7 +776,7 @@ void gather_qmm(

   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && transpose && (K % 64 == 0) &&
-        (x.dtype() != float32 || env::enable_tf32())) {
+        (env::enable_tf32() || x.dtype() != float32)) {
       return gather_qmm_nax(
           /* const array& x = */ x,
           /* const array& w = */ w,

@@ -1130,7 +1130,7 @@ void gather_qmm_rhs(

   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && transpose &&
-        (x_.dtype() != float32 || env::enable_tf32())) {
+        (env::enable_tf32() || x_.dtype() != float32)) {
       return gather_qmm_rhs_nax(
           /* const array& x_ = */ x_,
           /* const array& w_ = */ w_,
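
The qmm and gather_qmm changes above only affect which Metal kernel gets dispatched; the Python entry point, mx.quantized_matmul, is unchanged. A minimal usage sketch, assuming the standard mx.quantize / mx.quantized_matmul API (shapes and quantization parameters here are illustrative, not taken from this diff):

import mlx.core as mx

# Quantize a weight matrix and multiply; the backend decides whether the
# NAX kernel is used based on the conditions gated in the hunks above.
x = mx.random.normal((2, 512)).astype(mx.float16)
w = mx.random.normal((512, 512)).astype(mx.float16)
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
y = mx.quantized_matmul(x, w_q, scales, biases, group_size=64, bits=4)
print(y.shape)  # (2, 512)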

@@ -166,7 +166,7 @@ void sdpa_full_self_attention_metal(
 #ifdef MLX_ENABLE_NAX
   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && q.shape(3) != 80 &&
-        (q.dtype() != float32 || env::enable_tf32())) {
+        (env::enable_tf32() || q.dtype() != float32)) {
       return sdpa_full_self_attention_nax(
           /* const Stream& s = */ s,
           /* metal::Device& d = */ d,
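
The attention hunk above applies the same TF32 gating to the fused attention kernel, which is reached from Python through mx.fast.scaled_dot_product_attention. A small illustrative call (shapes are made up; this is not part of the diff):

import mlx.core as mx

# Batch 1, 8 heads, sequence length 128, head dimension 64, in half precision.
q = mx.random.normal((1, 8, 128, 64)).astype(mx.float16)
k = mx.random.normal((1, 8, 128, 64)).astype(mx.float16)
v = mx.random.normal((1, 8, 128, 64)).astype(mx.float16)
o = mx.fast.scaled_dot_product_attention(q, k, v, scale=64**-0.5)
print(o.shape)  # (1, 8, 128, 64)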

@@ -163,7 +163,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
-        dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
+        dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
         tests = product(
             [128, 64, 32],  # group_size
             [2, 4, 8],  # bits

@@ -195,7 +195,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 self.assertEqual(y_q.shape, y_hat.shape)

-                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                tol = 1e-3 if dtype == mx.float32 else 1.5e-3
+                self.assertLess((y_q - y_hat).abs().max(), tol)

     def test_qmm_vjp(self):
         key = mx.random.key(0)
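
The widened bound above only applies when the test dtype is not float32. The same check, restated as a standalone helper (the function name is made up for illustration; y_q and y_hat stand for any quantized and reference outputs):

import mlx.core as mx

def assert_close(y_q, y_hat, dtype):
    # Half-precision results carry more rounding error than float32,
    # so accept a slightly larger maximum absolute difference there.
    tol = 1e-3 if dtype == mx.float32 else 1.5e-3
    assert (y_q - y_hat).abs().max().item() < tol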

@@ -844,16 +844,20 @@ class TestQuantized(mlx_tests.MLXTestCase):

         key = mx.random.key(0)
         k1, k2, k3 = mx.random.split(key, 3)
-        dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
+        dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32

         for L, K, D, E, I, transpose, mode in parameters:
             with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
                 if mode == "mxfp4":
                     group_size = 32
-                    dtype = mx.bfloat16 if mx.is_available(mx.gpu) else mx.float32
+                    dtype = (
+                        mx.bfloat16 if (mx.default_device() == mx.gpu) else mx.float32
+                    )
                 else:
                     group_size = 64
-                    dtype = mx.float16 if mx.is_available(mx.gpu) else mx.float32
+                    dtype = (
+                        mx.float16 if (mx.default_device() == mx.gpu) else mx.float32
+                    )

                 K, D = (K, D) if transpose else (D, K)
                 ishape = (L, I)
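
The last hunk picks both the group size and the test dtype per quantization mode. A compact restatement of that selection as a helper (hypothetical function name, same logic as the diff):

import mlx.core as mx

def pick_mode_params(mode):
    # mxfp4 uses 32-element groups and bfloat16 on the GPU; the other modes
    # use 64-element groups and float16. Both fall back to float32 when the
    # default device is the CPU.
    on_gpu = mx.default_device() == mx.gpu
    if mode == "mxfp4":
        return 32, (mx.bfloat16 if on_gpu else mx.float32)
    return 64, (mx.float16 if on_gpu else mx.float32)

group_size, dtype = pick_mode_params("mxfp4")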