From 18807aae0bc7701815f3173cbe44e746f00342f8 Mon Sep 17 00:00:00 2001
From: Jagrit Digani
Date: Wed, 19 Nov 2025 11:13:26 -0800
Subject: [PATCH] Add fixes to int QMMs (CI passing)

---
 mlx/backend/metal/kernels/quantized_nax.h |  51 ++++----
 mlx/backend/metal/matmul.cpp              |   5 +-
 mlx/backend/metal/quantized.cpp           | 151 +++++++++++++++++++++-
 python/tests/test_quantized.py            |  57 +++++---
 4 files changed, 221 insertions(+), 43 deletions(-)

diff --git a/mlx/backend/metal/kernels/quantized_nax.h b/mlx/backend/metal/kernels/quantized_nax.h
index 57d0a7211..ef0b8e368 100644
--- a/mlx/backend/metal/kernels/quantized_nax.h
+++ b/mlx/backend/metal/kernels/quantized_nax.h
@@ -860,9 +860,6 @@ METAL_FUNC void qmm_t_nax_tgp_impl(
   const short tm = SM * (simd_gid / WN);
   const short tn = SN * (simd_gid % WN);
 
-  const short lda_tgp = BK_padded;
-  const short ldb_tgp = BK_padded;
-
   constexpr bool transpose_a = false;
   constexpr bool transpose_b = true;
 
@@ -898,7 +895,7 @@ METAL_FUNC void qmm_t_nax_tgp_impl(
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-#pragma clang loop unroll(disable)
+    STEEL_PRAGMA_NO_UNROLL
     for (int kk1 = 0; kk1 < BK; kk1 += SK) {
       NAXTile Atile;
       NAXTile Btile;
@@ -911,7 +908,7 @@ METAL_FUNC void qmm_t_nax_tgp_impl(
         Atile.load_safe(x + kk1, K, short2(SK, sgp_sm));
       }
 
-      Btile.template load(Ws + tn * ldb_tgp + kk1);
+      Btile.template load(Ws + tn * BK_padded + kk1);
 
       tile_matmad_nax(
           Dtile,
@@ -964,6 +961,8 @@ METAL_FUNC void qmm_n_nax_tgp_impl(
     uint lid [[thread_index_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
+  (void)lid;
+
   static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
   static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
 
@@ -997,8 +996,8 @@ METAL_FUNC void qmm_n_nax_tgp_impl(
   y += y_row * static_cast(N) + y_col;
 
   // Make the x loader and mma operation
-  const short num_els = min(BM, M - y_row);
-  const short num_outs = min(BN, N - y_col);
+  // const short num_els = min(BM, M - y_row);
+  // const short num_outs = min(BN, N - y_col);
   loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
 
   constexpr short UM = 16;
@@ -1037,7 +1036,7 @@ METAL_FUNC void qmm_n_nax_tgp_impl(
     loader_w.load_unsafe();
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-#pragma clang loop unroll(disable)
+    STEEL_PRAGMA_NO_UNROLL
     for (int kk1 = 0; kk1 < BK; kk1 += SK) {
       NAXTile Atile;
      NAXTile Btile;
@@ -1408,10 +1407,17 @@ template <
   const short tm = SM * (simd_group_id / WN);
   const short tn = SN * (simd_group_id % WN);
 
-  const short sgp_sm = align_M ? SM : min(SM, short(M - (y_row + tm)));
+  const short sgp_sm =
+      align_M ? SM : min(SM, short(max(0, (M - (y_row + tm)))));
+  const short sgp_sn =
+      align_N ? SN : min(SN, short(max(0, (N - (y_col + tn)))));
+
   const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
   const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN);
 
+  constexpr short BR = transpose ? TN : TK;
+  constexpr short BC = transpose ? TK : TN;
+
   using AccumType = float;
   using ASubTile = NAXSubTile;
@@ -1467,11 +1473,10 @@ template <
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-#pragma clang loop unroll(disable)
+    STEEL_PRAGMA_NO_UNROLL
     for (int kk1 = 0; kk1 < BK; kk1 += SK) {
       NAXTile Atile;
-      NAXTile
-          Btile;
+      NAXTile Btile;
 
       volatile int compiler_barrier;
 
@@ -1506,15 +1511,15 @@ template <
     loader_w.load_safe(tile_w);
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-#pragma clang loop unroll(disable)
+    STEEL_PRAGMA_NO_UNROLL
     for (int kk1 = 0; kk1 < BK; kk1 += SK) {
       NAXTile Atile;
-      NAXTile
-          Btile;
+      NAXTile Btile;
 
       volatile int compiler_barrier;
 
-      Atile.load_safe(xn + kk1, K, short2((BK - kk1), sgp_sm));
+      const short psk = min(int(SK), max(0, (BK - kk1)));
+      Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm));
 
       if constexpr (transpose) {
         Btile.template load(Ws + tn * BK_padded + kk1);
@@ -1535,23 +1540,23 @@ template <
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
+    const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm));
+    const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm));
+
     // Store results to device memory
     if constexpr (kAlignedN.value) {
-      if ((offset_next - offset) == BM) {
+      if (m_lo_lim == 0 && m_hi_lim == SM) {
        Dtile.store(y + tm * N + tn, N);
      } else {
        Dtile.store_slice(
-            y + tm * N + tn,
-            N,
-            short2(0, min(int(sgp_sm), max(0, offset - tm))),
-            short2(BN, min(int(sgp_sm), max(0, offset_next - tm))));
+            y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim));
      }
    } else {
      Dtile.store_slice(
          y + tm * N + tn,
          N,
-          short2(0, max(0, offset - tm)),
-          short2(max(0, tgp_bn - tn), max(0, offset_next - tm)));
+          short2(0, m_lo_lim),
+          short2(sgp_sn, m_hi_lim));
    }
  });
 });
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index 94e9e80b9..3e64e82da 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -1825,7 +1825,7 @@ void gather_mm_rhs_nax(
   base_name.reserve(64);
   concatenate(
       base_name,
-      "steel_gather_mm_rhs_mxu_n",
+      "steel_gather_mm_rhs_nax_n",
       transpose_b ? 't' : 'n',
       '_',
       type_to_name(a),
@@ -2200,7 +2200,8 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   if (__builtin_available(
           macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
-    if (metal::is_nax_available() && a.dtype() != float32) {
+    if (metal::is_nax_available() &&
+        (a.dtype() != float32 || env::enable_tf32())) {
       return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
     }
   }
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index 15a5ceeb8..0594c02f2 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -539,6 +539,120 @@ void qmm_nax(
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
+void gather_qmm_nax(
+    const array& x,
+    const array& w,
+    const array& scales,
+    const std::optional<array>& biases,
+    const array& lhs_indices,
+    const array& rhs_indices,
+    array& out,
+    bool transpose,
+    int group_size,
+    int bits,
+    int M,
+    int N,
+    int K,
+    metal::Device& d,
+    const Stream& s,
+    const std::string& mode) {
+  int B = out.size() / M / N;
+
+  int wm = 2;
+  int wn = 2;
+  int bm = 64;
+  int bn = 64;
+  int bk = 32;
+  MTL::Size group_dims(32, wn, wm);
+  MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B);
+
+  std::string kname;
+  kname.reserve(64);
+  bool aligned = N % 64 == 0;
+  std::string type_string = get_type_string(x.dtype());
+  concatenate(
+      kname,
+      mode + (transpose ? "_gather_qmm_t_nax_" : "_gather_qmm_n_nax_"),
"_gather_qmm_t_nax_" : "_gather_qmm_n_nax_"), + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn, + transpose ? (aligned ? "_alN_true" : "_alN_false") : ""); + MTL::ComputePipelineState* kernel; + if (transpose) { + kernel = get_quantized_kernel_wrapped( + d, + kname, + "gather_qmm_t_nax_", + mode, + type_string, + group_size, + bits, + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn, + aligned); + } else { + kernel = get_quantized_kernel_wrapped( + d, + kname, + "gather_qmm_n_nax_", + mode, + type_string, + group_size, + bits, + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn); + } + + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_input_array(lhs_indices, c++); + compute_encoder.set_input_array(rhs_indices, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + compute_encoder.set_bytes(M, c++); + c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c); + add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + #endif // MLX_ENABLE_NAX void qmm( @@ -559,8 +673,9 @@ void qmm( #ifdef MLX_ENABLE_NAX if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { - if (metal::is_nax_available() && mode == "affine" && (group_size >= 64) && - transpose && (M % 64 == 0) && (N % 64 == 0) && (K % 64 == 0)) { + if (metal::is_nax_available() && transpose && + (x.dtype() != float32 || env::enable_tf32()) && mode == "affine" && + (group_size >= 64) && (K % 64 == 0)) { return qmm_nax( /* const array& x = */ x, /* const array& w = */ w, @@ -658,6 +773,34 @@ void gather_qmm( metal::Device& d, const Stream& s, const std::string& mode) { +#ifdef MLX_ENABLE_NAX + + if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { + if (metal::is_nax_available() && transpose && + (x.dtype() != float32 || env::enable_tf32()) && transpose && + mode == "affine" && (group_size >= 64) && (K % 64 == 0)) { + return gather_qmm_nax( + /* const array& x = */ x, + /* const array& w = */ w, + /* const array& scales = */ scales, + /* const std::optional& biases = */ biases, + /* const array& lhs_indices = */ lhs_indices, + /* const array& rhs_indices = */ rhs_indices, + /* array& out = */ out, + /* bool transpose = */ transpose, + /* int group_size = */ group_size, + /* int bits = */ bits, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* metal::Device& d = */ d, + /* const Stream& s = */ s, + /* const std::string& mode = */ mode); + } + } + +#endif // MLX_ENABLE_NAX + int B = out.size() / M / N; int wm = 2; @@ -988,7 +1131,9 @@ void gather_qmm_rhs( #ifdef MLX_ENABLE_NAX if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { - if (metal::is_nax_available() && mode == "affine" && (group_size >= 64)) { + if (metal::is_nax_available() && + (x_.dtype() != float32 || env::enable_tf32()) && mode == "affine" && + (group_size >= 64)) { return gather_qmm_rhs_nax( /* const array& x_ = */ x_, /* const array& w_ = */ w_, diff --git a/python/tests/test_quantized.py 
index 5ae1d8104..ee6cef6ec 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -163,6 +163,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
+        dtype = mx.float16
         tests = product(
             [128, 64, 32],  # group_size
             [2, 4, 8],  # bits
@@ -178,8 +179,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 bits=bits,
                 transposed=transposed,
             ):
-                x = mx.random.normal(shape=(M, K), key=k1)
-                w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
+                x = mx.random.normal(shape=(M, K), key=k1) / K**0.5
+                w = (
+                    mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
+                    / K**0.5
+                )
+                x = x.astype(dtype)
+                w = w.astype(dtype)
                 w_q, scales, biases = mx.quantize(w, group_size, bits)
                 w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                 y_q = mx.quantized_matmul(
@@ -833,20 +839,34 @@ class TestQuantized(mlx_tests.MLXTestCase):
             (133, 512, 555, 4, 2, False, "affine"),
             (64, 512, 512, 4, 2, False, "affine"),
         ]
+
+        key = mx.random.key(0)
+        k1, k2, k3 = mx.random.split(key, 3)
+        dtype = mx.float16
+
         for L, K, D, E, I, transpose, mode in parameters:
             with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode):
                 if mode == "mxfp4":
                     group_size = 32
+                    dtype = mx.bfloat16
                 else:
                     group_size = 64
+                    dtype = mx.float16
+
                 K, D = (K, D) if transpose else (D, K)
                 ishape = (L, I)
                 xshape = (L, 1, 1, K)
                 wshape = (E, D, K) if transpose else (E, K, D)
-                indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
-                x = mx.random.normal(xshape) / K**0.5
-                w = mx.random.normal(wshape) / K**0.5
+                indices = (mx.random.uniform(shape=ishape, key=k1) * E).astype(
+                    mx.uint32
+                )
+                x = mx.random.normal(xshape, key=k2) / K**0.5
+                w = mx.random.normal(wshape, key=k3) / K**0.5
+
+                x = x.astype(dtype)
+                w = w.astype(dtype)
+
                 w, *wq = quantize(
                     w, group_size=group_size, mode=mode, transpose=transpose
                 )
@@ -875,13 +895,15 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 y3 = scatter_unsort(y3, inv_order, indices.shape)
                 y4 = scatter_unsort(y4, inv_order, indices.shape)
 
-                self.assertLess((y1 - y2).abs().max(), 1e-5)
-                self.assertLess((y1 - y3).abs().max(), 1e-5)
-                self.assertLess((y1 - y4).abs().max(), 2e-4)
+                tol = 1.5e-5 if (dtype == mx.float32) else 2.5e-4
 
-                self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
-                self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
-                self.assertTrue(mx.allclose(y1, y4, atol=2e-4))
+                self.assertLess((y1 - y2).abs().max(), tol)
+                self.assertLess((y1 - y3).abs().max(), tol)
+                self.assertLess((y1 - y4).abs().max(), tol)
+
+                self.assertTrue(mx.allclose(y1, y2, atol=tol))
+                self.assertTrue(mx.allclose(y1, y3, atol=tol))
+                self.assertTrue(mx.allclose(y1, y4, atol=tol))
 
     def test_gather_qmm_grad(self):
         def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):
@@ -905,10 +927,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
                 sorted_indices=sort,
             )
 
-        x = mx.random.normal((16, 1, 256))
-        w, s, b = mx.quantize(mx.random.normal((4, 256, 256)))
-        indices = mx.sort(mx.random.randint(0, 4, shape=(16,)))
-        cotan = mx.random.normal((16, 1, 256))
+        key = mx.random.key(0)
+        k1, k2, k3, k4 = mx.random.split(key, 4)
+        dtype = mx.float32
+
+        x = mx.random.normal((16, 1, 256), key=k1).astype(dtype)
+        w, s, b = mx.quantize(mx.random.normal((4, 256, 256), key=k2).astype(dtype))
+        indices = mx.sort(mx.random.randint(0, 4, shape=(16,), key=k3))
+        cotan = mx.random.normal((16, 1, 256), key=k4).astype(dtype)
 
         (o1,), (dx1, ds1, db1) = mx.vjp(
             lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),
@@ -921,6 +947,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
             [cotan],
         )
 
+        self.assertLess((o1 - o2).abs().max(), 1e-4)
         self.assertTrue(mx.allclose(o1, o2, atol=1e-4))
         self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4))
         self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))