From 166dfac5cf7d35e19b1add8b50332e3053dc4310 Mon Sep 17 00:00:00 2001
From: Jagrit Digani
Date: Wed, 19 Nov 2025 12:00:59 -0800
Subject: [PATCH] Add group size 32 int qmms

---
 mlx/backend/metal/kernels/quantized_nax.h     | 142 +++++++++++++++++-
 mlx/backend/metal/kernels/quantized_nax.metal |   3 +-
 mlx/backend/metal/quantized.cpp               |   6 +-
 3 files changed, 146 insertions(+), 5 deletions(-)

diff --git a/mlx/backend/metal/kernels/quantized_nax.h b/mlx/backend/metal/kernels/quantized_nax.h
index ef0b8e368..0f07e4fc8 100644
--- a/mlx/backend/metal/kernels/quantized_nax.h
+++ b/mlx/backend/metal/kernels/quantized_nax.h
@@ -692,6 +692,146 @@ struct QuantizedBlockLoader {
   }
 };
 
+template <
+    typename T,
+    short BROWS,
+    short BCOLS,
+    short dst_ld,
+    short reduction_dim,
+    short tgp_size,
+    short bits>
+struct QuantizedBlockLoader<
+    T,
+    BROWS,
+    BCOLS,
+    dst_ld,
+    reduction_dim,
+    tgp_size,
+    32,
+    bits> {
+  MLX_MTL_CONST short group_size = 32;
+
+  static_assert(
+      BCOLS % group_size == 0,
+      "The group size should be divisible by the columns");
+  static_assert(
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
+
+  MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();
+  MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();
+  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
+  MLX_MTL_CONST short n_reads =
+      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
+  MLX_MTL_CONST short n_groups = BCOLS / group_size;
+
+  static_assert(
+      (BCOLS_PACKED / n_reads) == n_groups,
+      "Other configurations are not yet supported");
+
+  const int src_ld;
+  const int tile_stride;
+  const int group_stride;
+
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  const short group_id;
+
+  threadgroup T* dst;
+  const device uint8_t* src;
+  const device T* scales;
+  const device T* biases;
+
+  QuantizedBlockLoader(
+      const device uint8_t* src_,
+      const device T* scales_,
+      const device T* biases_,
+      const int src_ld_,
+      threadgroup T* dst_,
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(src_ld_),
+        tile_stride(
+            reduction_dim ? BCOLS_PACKED * bytes_per_pack
+                          : BROWS * src_ld * bytes_per_pack / pack_factor),
+        group_stride(BROWS * src_ld / group_size),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(n_reads * thread_idx / BCOLS_PACKED),
+        bj((n_reads * thread_idx) % BCOLS_PACKED),
+        group_id((bj * pack_factor) / group_size),
+        dst(dst_ + bi * dst_ld + bj * pack_factor),
+        src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
+            bj * bytes_per_pack),
+        scales(scales_ + bi * src_ld / group_size + group_id),
+        biases(biases_ + bi * src_ld / group_size + group_id) {}
+
+  void load_unsafe() const {
+    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    T scale = *scales;
+    T bias = *biases;
+    for (int i = 0; i < n_reads; i++) {
+      dequantize<T, pack_factor, bits>(
+          src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
+    }
+  }
+
+  void load_safe(short2 src_tile_dim) const {
+    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    if (reduction_dim == 1 && bi >= src_tile_dim.x) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    if (reduction_dim == 0 && bi >= src_tile_dim.y) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    T scale = *scales;
+    T bias = *biases;
+    for (int i = 0; i < n_reads; i++) {
+      dequantize<T, pack_factor, bits>(
+          (device uint8_t*)(src + i * bytes_per_pack),
+          scale,
+          bias,
+          dst + i * pack_factor);
+    }
+  }
+
+  void next() {
+    src += tile_stride;
+    if (reduction_dim == 1) {
+      // if (group_steps > 1) {
+      //   group_step_cnt++;
+      //   if (group_step_cnt == group_steps) {
+      //     group_step_cnt = 0;
+      //     scales++;
+      //     biases++;
+      //   }
+      // } else {
+      scales += n_groups;
+      biases += n_groups;
+      // }
+    } else {
+      scales += n_groups * group_stride;
+      biases += n_groups * group_stride;
+    }
+  }
+};
+
 template <typename T>
 METAL_FUNC void adjust_matrix_offsets(
     const device T*& x,
@@ -843,7 +983,7 @@ METAL_FUNC void qmm_t_nax_tgp_impl(
   biases += y_col * K_g;
   y += y_row * static_cast<int64_t>(N) + y_col;
 
-  // Make the x loader and mma operation
+  // Make the weight loader
   loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
 
   constexpr short UM = 16;
diff --git a/mlx/backend/metal/kernels/quantized_nax.metal b/mlx/backend/metal/kernels/quantized_nax.metal
index 98703a608..5a9c9fb87 100644
--- a/mlx/backend/metal/kernels/quantized_nax.metal
+++ b/mlx/backend/metal/kernels/quantized_nax.metal
@@ -92,7 +92,8 @@
 
 #define instantiate_quantized_groups(bits) \
   instantiate_quantized_types(128, bits)   \
-  instantiate_quantized_types(64, bits)
+  instantiate_quantized_types(64, bits)    \
+  instantiate_quantized_types(32, bits)
 
 #define instantiate_quantized_all() \
   instantiate_quantized_groups(2)   \
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index 0594c02f2..4906fa748 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -675,7 +675,7 @@ void qmm(
   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && transpose &&
         (x.dtype() != float32 || env::enable_tf32()) && mode == "affine" &&
-        (group_size >= 64) && (K % 64 == 0)) {
+        (K % 64 == 0)) {
       return qmm_nax(
           /* const array& x = */ x,
           /* const array& w = */ w,
@@ -777,8 +777,8 @@ void gather_qmm(
 
   if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
     if (metal::is_nax_available() && transpose &&
-        (x.dtype() != float32 || env::enable_tf32()) && transpose &&
-        mode == "affine" && (group_size >= 64) && (K % 64 == 0)) {
+        (x.dtype() != float32 || env::enable_tf32()) && mode == "affine" &&
+        (K % 64 == 0)) {
       return gather_qmm_nax(
           /* const array& x = */ x,
           /* const array& w = */ w,
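
Note on the new loader specialization (not part of the diff above): the static_assert
(BCOLS_PACKED / n_reads) == n_groups means each thread's n_reads contiguous packed values
span exactly one scale/bias group (n_reads * pack_factor == group_size per group), which is
why a single scales/biases read per thread is enough and next() can advance the pointers by
whole groups. The host-side C++ sketch below re-derives that layout arithmetic so it can be
checked off-device; the packing rules are restated here as assumptions mirroring the
get_pack_factor / get_bytes_per_pack helpers, and the tile shape is an example value, not
the kernels' actual launch configuration.

// layout_check.cpp - standalone illustration, not part of the patch.
// Re-derives the compile-time layout math of the group_size = 32 loader so the
// "(BCOLS_PACKED / n_reads) == n_groups" constraint can be inspected on the host.
#include <cstdio>

constexpr int pack_factor(int bits) {
  // Assumed packing rules: 3- and 5-bit values pack 8 per 3/5-byte pack,
  // 6-bit packs 4 per 3 bytes, power-of-two widths pack 8 / bits per byte.
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : 8 / bits);
}

constexpr int bytes_per_pack(int bits) {
  return ((bits & (bits - 1)) == 0) ? 1 : (bits == 5 ? 5 : 3);
}

int main() {
  // Example tile shape only; the real kernels pick BROWS/BCOLS/tgp_size
  // per kernel configuration.
  constexpr int group_size = 32, BCOLS = 32, BROWS = 128, tgp_size = 128;
  constexpr int n_groups = BCOLS / group_size;
  const int all_bits[] = {2, 3, 4, 5, 6, 8};
  for (int bits : all_bits) {
    int packed = BCOLS / pack_factor(bits); // BCOLS_PACKED
    int n_reads =
        (packed * BROWS < tgp_size) ? 1 : (packed * BROWS) / tgp_size;
    // True exactly when each thread's reads cover one scale/bias group.
    bool one_group_per_thread = (packed / n_reads) == n_groups;
    std::printf(
        "bits=%d pack_factor=%d bytes_per_pack=%d BCOLS_PACKED=%d "
        "n_reads=%d one_group_per_thread=%d\n",
        bits, pack_factor(bits), bytes_per_pack(bits), packed, n_reads,
        (int)one_group_per_thread);
  }
  return 0;
}

With BCOLS = group_size = 32, n_groups is 1, so the constraint forces n_reads == BCOLS_PACKED:
each thread loads one full packed row of the tile and reads a single scale and bias for it.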