From 51449428ddfae13e4ad75ae8b72b509d17c6689c Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 20 Aug 2025 14:05:35 -0700
Subject: [PATCH] speedup

---
 mlx/backend/metal/kernels/fp4_quantized.h | 222 +++++++++++++++-------
 python/tests/test_quantized.py            |   2 +
 2 files changed, 157 insertions(+), 67 deletions(-)

diff --git a/mlx/backend/metal/kernels/fp4_quantized.h b/mlx/backend/metal/kernels/fp4_quantized.h
index 0ce9bc35d..1c261b3c0 100644
--- a/mlx/backend/metal/kernels/fp4_quantized.h
+++ b/mlx/backend/metal/kernels/fp4_quantized.h
@@ -1,4 +1,4 @@
-// Copyright © 2023-2024 Apple Inc.
+// Copyright © 2025 Apple Inc.
 
 #include <metal_simdgroup>
 #include <metal_stdlib>
@@ -24,6 +24,17 @@ inline constexpr short get_bytes_per_pack() {
   return wsize / 8;
 }
 
+template <typename T>
+static inline T dequantize_scale(uint8_t s) {
+  using FOrI = union {
+    bfloat16_t f;
+    uint16_t i;
+  };
+  FOrI out;
+  out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
+  return static_cast<T>(out.f);
+}
+
 template <typename T, typename U, int values_per_thread>
 inline void load_vector(const device T* x, thread U* x_thread) {
   for (int i = 0; i < values_per_thread; i += 4) {
@@ -48,7 +59,7 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
   }
 }
 
-constant float MXFP4_LUT[16] = {
+constexpr constant static float MXFP4_LUT[16] = {
     +0.0f,
     +0.5f,
     +1.0f,
@@ -66,51 +77,74 @@
     -4.0f,
     -6.0f};
 
-template <typename U, int values_per_thread>
-inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) {
-  U accum = 0;
+template <typename T>
+void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
+  if (simd_gid == 0 && simd_lid < 16) {
+    lut[simd_lid] = static_cast<T>(MXFP4_LUT[simd_lid]);
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+}
 
+template <typename U, int values_per_thread>
+inline U qdot(
+    const device uint8_t* w,
+    const thread U* x_thread,
+    U scale,
+    const threadgroup U* lut) {
+  U accum = 0;
   const device uint16_t* ws = (const device uint16_t*)w;
   for (int i = 0; i < (values_per_thread / 4); i++) {
     accum +=
-        (x_thread[4 * i] * MXFP4_LUT[ws[i] & 0x000f] +
-         x_thread[4 * i + 1] * MXFP4_LUT[(ws[i] & 0x00f0) >> 4] +
-         x_thread[4 * i + 2] * MXFP4_LUT[(ws[i] & 0x0f00) >> 8] +
-         x_thread[4 * i + 3] * MXFP4_LUT[(ws[i] & 0xf000) >> 12]);
+        (x_thread[4 * i] * lut[ws[i] & 0x000f] +
+         x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0x000f] +
+         x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0x000f] +
+         x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0x000f]);
   }
 
   return scale * accum;
 }
 
-template <typename U, int values_per_thread, typename S>
-inline U
-qdot_safe(const device uint8_t* w, const thread U* x_thread, S scale, int N) {
+template <typename U, int values_per_thread>
+inline U qdot_safe(
+    const device uint8_t* w,
+    const thread U* x_thread,
+    U scale,
+    const threadgroup U* lut,
+    int N) {
   U accum = 0;
-
   const device uint16_t* ws = (const device uint16_t*)w;
   for (int i = 0; i < (N / 4); i++) {
     accum +=
-        (x_thread[4 * i] * MXFP4_LUT[ws[i] & 0x000f] +
-         x_thread[4 * i + 1] * MXFP4_LUT[(ws[i] & 0x00f0) >> 4] +
-         x_thread[4 * i + 2] * MXFP4_LUT[(ws[i] & 0x0f00) >> 8] +
-         x_thread[4 * i + 3] * MXFP4_LUT[(ws[i] & 0xf000) >> 12]);
+        (x_thread[4 * i] * lut[ws[i] & 0x000f] +
+         x_thread[4 * i + 1] * lut[(ws[i] & 0x00f0) >> 4] +
+         x_thread[4 * i + 2] * lut[(ws[i] & 0x0f00) >> 8] +
+         x_thread[4 * i + 3] * lut[(ws[i] & 0xf000) >> 12]);
   }
 
   return scale * accum;
 }
 
 template <typename U, int values_per_thread>
-inline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) {
+inline void qouter(
+    const thread uint8_t* w,
+    U x,
+    U scale,
+    thread U* result,
+    const threadgroup U* lut) {
   for (int i = 0; i < (values_per_thread / 2); i++) {
-    result[2 * i] += x * scale * MXFP4_LUT[w[i] & 0x0f];
-    result[2 * i + 1] += x * scale * MXFP4_LUT[(w[i] & 0xf0) >> 4];
+    result[2 * i] += x * scale * lut[w[i] & 0x0f];
+    result[2 * i + 1] += x * scale * lut[(w[i] & 0xf0) >> 4];
   }
 }
 
 template <typename U, int N>
-inline void
-dequantize(const device uint8_t* w, U scale, threadgroup U* w_local) {
+inline void dequantize(
+    const device uint8_t* w,
+    U scale,
+    threadgroup U* w_local,
+    const threadgroup U* lut) {
   for (int i = 0; i < (N / 2); i++) {
-    w_local[2 * i] = scale * static_cast<U>(MXFP4_LUT[w[i] & 0x0f]);
-    w_local[2 * i + 1] = scale * static_cast<U>(MXFP4_LUT[(w[i] & 0xf0) >> 4]);
+    w_local[2 * i] = scale * lut[w[i] & 0x0f];
+    w_local[2 * i + 1] = scale * lut[(w[i] & 0xf0) >> 4];
   }
 }
 
@@ -150,12 +184,14 @@ struct QuantizedBlockLoader {
   threadgroup T* dst;
   const device uint8_t* src;
   const device S* scales;
+  threadgroup T* lut;
 
   QuantizedBlockLoader(
       const device uint8_t* src_,
      const device S* scales_,
      const int src_ld_,
      threadgroup T* dst_,
+      threadgroup T* lut_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
@@ -170,17 +206,20 @@
         dst(dst_ + bi * dst_ld + bj * pack_factor),
         src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
             bj * bytes_per_pack),
-        scales(scales_ + bi * src_ld / group_size) {}
+        scales(scales_ + bi * src_ld / group_size),
+        lut(lut_) {
+    load_mxfp4_lut(lut, simd_group_id, simd_lane_id);
+  }
 
   void load_unsafe() const {
     if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
       return;
     }
 
-    T scale = metal::pow(T(2.0), static_cast<T>(*scales) - 127);
+    T scale = dequantize_scale<T>(*scales);
     for (int i = 0; i < n_reads; i++) {
       dequantize<T, pack_factor>(
-          src + i * bytes_per_pack, scale, dst + i * pack_factor);
+          src + i * bytes_per_pack, scale, dst + i * pack_factor, lut);
     }
   }
 
@@ -203,12 +242,13 @@
       return;
     }
 
-    T scale = metal::pow(T(2.0), static_cast<T>(*scales) - 127);
+    T scale = dequantize_scale<T>(*scales);
     for (int i = 0; i < n_reads; i++) {
       dequantize<T, pack_factor>(
           (device uint8_t*)(src + i * bytes_per_pack),
           scale,
-          dst + i * pack_factor);
+          dst + i * pack_factor,
+          lut);
     }
   }
 
@@ -240,7 +280,10 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
     const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint quad_gid [[quadgroup_index_in_threadgroup]],
-    uint quad_lid [[thread_index_in_quadgroup]]) {
+    uint quad_lid [[thread_index_in_quadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]],
+    threadgroup float* lut) {
   constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
   constexpr int pack_factor = 8;
   constexpr int values_per_thread = D / QUAD_SIZE;
@@ -252,6 +295,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
 
   thread U x_thread[values_per_thread];
   thread U result[results_per_quadgroup] = {0};
+  load_mxfp4_lut(lut, simd_gid, simd_lid);
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size / pack_factor;
@@ -269,9 +313,9 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
     auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
     const device S* sl = scales + row * in_vec_size_g * quads_per_simd;
 
-    U s = metal::pow(2.0f, static_cast<U>(sl[0]) - 127);
+    U s = dequantize_scale<U>(sl[0]);
     if (row * quads_per_simd + out_row < out_vec_size) {
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
     }
   }
 
@@ -293,7 +337,8 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
     const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]]) {
+    uint simd_lid [[thread_index_in_simdgroup]],
+    threadgroup float* lut) {
   constexpr int packs_per_thread = 2;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
@@ -306,9 +351,9 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
   const device uint8_t* ws = (const device uint8_t*)w;
 
   typedef float U;
-
   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};
+  load_mxfp4_lut(lut, simd_gid, simd_lid);
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -328,8 +373,8 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
       auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
       const device auto* sl = scales + row * in_vec_size_g;
 
-      U s = metal::pow(2.0f, static_cast<U>(sl[0]) - 127);
-      result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
+      U s = dequantize_scale<U>(sl[0]);
+      result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
     }
 
     ws += block_size * bytes_per_pack / pack_factor;
@@ -355,7 +400,8 @@ METAL_FUNC void mxfp4_qmv_impl(
     const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]]) {
+    uint simd_lid [[thread_index_in_simdgroup]],
+    threadgroup float* lut) {
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
   constexpr int packs_per_thread = 1;
@@ -372,6 +418,7 @@ METAL_FUNC void mxfp4_qmv_impl(
 
   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};
+  load_mxfp4_lut(lut, simd_gid, simd_lid);
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -402,7 +449,7 @@ METAL_FUNC void mxfp4_qmv_impl(
         const device auto* sl = scales + row * in_vec_size_g;
 
         S s = sl[0];
-        result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
+        result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
       }
 
       ws += block_size * bytes_per_pack / pack_factor;
@@ -420,8 +467,8 @@ METAL_FUNC void mxfp4_qmv_impl(
         auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
         const device auto* sl = scales + row * in_vec_size_g;
 
-        U s = metal::pow(2.0f, static_cast<U>(sl[0]) - 127);
-        result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
+        U s = dequantize_scale<U>(sl[0]);
+        result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
       }
     }
 
@@ -449,8 +496,8 @@ METAL_FUNC void mxfp4_qmv_impl(
         auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
         const device auto* sl = scales + row * in_vec_size_g;
 
-        U s = metal::pow(2.0f, static_cast<U>(sl[0]) - 127);
-        result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
+        U s = dequantize_scale<U>(sl[0]);
+        result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
       }
 
       ws += block_size * bytes_per_pack / pack_factor;
@@ -468,9 +515,9 @@ METAL_FUNC void mxfp4_qmv_impl(
       auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
       const device auto* sl = scales + row * in_vec_size_g;
 
-      U s = metal::pow(2.0f, static_cast<U>(sl[0]) - 127);
+      U s = dequantize_scale<U>(sl[0]);
       result[row] +=
-          qdot_safe<U, values_per_thread>(wl, x_thread, s, remaining);
+          qdot_safe<U, values_per_thread>(wl, x_thread, s, lut, remaining);
     }
   }
   for (int row = 0; row < results_per_simdgroup; row++) {
@@ -492,7 +539,8 @@ METAL_FUNC void mxfp4_qvm_impl(
     const int out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]]) {
+    uint simd_lid [[thread_index_in_simdgroup]],
+    threadgroup float* lut) {
   constexpr int num_simdgroups = 2;
   constexpr int pack_factor = get_pack_factor<32>();
   constexpr int bytes_per_pack = get_bytes_per_pack();
@@ -513,6 +561,8 @@ METAL_FUNC void mxfp4_qvm_impl(
   thread U scale = 0;
   thread U x_local = 0;
 
+  load_mxfp4_lut(lut, simd_gid, simd_lid);
+
   // Adjust positions
   const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
   const int out_vec_size_g = out_vec_size / group_size;
@@ -531,10 +581,10 @@ METAL_FUNC void mxfp4_qvm_impl(
   if (remaining == 0) {
     for (int i = 0; i < in_vec_size; i += block_size) {
       x_local = *x;
-      scale = metal::pow(2.0f, static_cast<U>(*scales) - 127);
+      scale = dequantize_scale<U>(*scales);
       w_local = *((device vec_w*)ws);
       qouter<U, pack_factor>(
-          (thread uint8_t*)&w_local, x_local, scale, result);
+          (thread uint8_t*)&w_local, x_local, scale, result, lut);
 
       x += block_size;
       scales += block_size * out_vec_size_g;
@@ -543,11 +593,11 @@ METAL_FUNC void mxfp4_qvm_impl(
  } else {
    for (int i = block_size; i < in_vec_size; i += block_size) {
      x_local = *x;
-      scale = metal::pow(2.0f, static_cast<U>(*scales) - 127);
+      scale = dequantize_scale<U>(*scales);
      w_local = *((device vec_w*)ws);
      qouter<U, pack_factor>(
-          (thread uint8_t*)&w_local, x_local, scale, result);
+          (thread uint8_t*)&w_local, x_local, scale, result, lut);
 
      x += block_size;
      scales += block_size * out_vec_size_g;
@@ -555,14 +605,14 @@ METAL_FUNC void mxfp4_qvm_impl(
    }
    if (static_cast<int>(simd_lid) < remaining) {
      x_local = *x;
-      scale = metal::pow(2.0f, static_cast<U>(*scales) - 127);
+      scale = dequantize_scale<U>(*scales);
      w_local = *((device vec_w*)ws);
    } else {
      x_local = 0;
      scale = 0;
    }
    qouter<U, pack_factor>(
-        (thread uint8_t*)&w_local, x_local, scale, result);
+        (thread uint8_t*)&w_local, x_local, scale, result, lut);
  }
 
   // Accumulate in the simdgroup
@@ -601,7 +651,8 @@ METAL_FUNC void mxfp4_qmm_t_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint lid [[thread_index_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]]) {
+    uint simd_lid [[thread_index_in_simdgroup]],
+    threadgroup T* lut) {
   static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
   static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
 
@@ -646,7 +697,7 @@ METAL_FUNC void mxfp4_qmm_t_impl(
   const short num_els = min(BM, M - y_row);
   const short num_outs = min(BN, N - y_col);
   loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
-  loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid);
+  loader_w_t loader_w(wl, scales, K, Ws, lut, simd_gid, simd_lid);
   mma_t mma_op(simd_gid, simd_lid);
 
   if (num_els < BM) {
@@ -725,7 +776,8 @@ METAL_FUNC void mxfp4_qmm_n_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint lid [[thread_index_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]]) {
+    uint simd_lid [[thread_index_in_simdgroup]],
+    threadgroup T* lut) {
   static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
   static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
 
@@ -767,7 +819,7 @@ METAL_FUNC void mxfp4_qmm_n_impl(
   // Make the x loader and mma operation
   const short num_els = min(BM, M - y_row);
   loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
-  loader_w_t loader_w(wl, scales, N, Ws, simd_gid, simd_lid);
+  loader_w_t loader_w(wl, scales, N, Ws, lut, simd_gid, simd_lid);
   mma_t mma_op(simd_gid, simd_lid);
 
   if (num_els < BM) {
@@ -941,7 +993,9 @@ template
     const constant int64_t* s_strides,
     uint3 tid [[threadgroup_position_in_grid]],
     uint quad_gid [[quadgroup_index_in_threadgroup]],
-    uint quad_lid [[thread_index_in_quadgroup]]) {
+    uint quad_lid [[thread_index_in_quadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
   if (batched) {
     int M = x_shape[x_batch_ndims];
     adjust_matrix_offsets(
@@ -959,8 +1013,20 @@ template
        s_strides,
        tid);
  }
+  threadgroup float lut[16];
   mxfp4_qmv_quad_impl(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid);
+      w,
+      scales,
+      x,
+      y,
+      in_vec_size,
+      out_vec_size,
+      tid,
+      quad_gid,
+      quad_lid,
+      simd_gid,
+      simd_lid,
+      lut);
 }
 
 template
@@ -998,8 +1064,9 @@ template
        s_strides,
        tid);
  }
+  threadgroup float lut[16];
   mxfp4_qmv_fast_impl(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
 template
@@ -1037,8 +1104,9 @@ template
        s_strides,
        tid);
  }
+  threadgroup float lut[16];
   mxfp4_qmv_impl(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
 template
@@ -1076,8 +1144,9 @@ template
        s_strides,
        tid);
  }
+  threadgroup float lut[16];
   mxfp4_qvm_impl(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
 template
@@ -1119,8 +1188,18 @@ template
   int in_vec_size_adj =
      tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
 
+  threadgroup float lut[16];
   mxfp4_qvm_impl(
-      w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid);
+      w,
+      scales,
+      x,
+      y,
+      in_vec_size_adj,
+      out_vec_size,
+      tid,
+      simd_gid,
+      simd_lid,
+      lut);
 }
 
 template <
@@ -1157,6 +1236,7 @@ template <
   threadgroup T Xs[BM * BK_padded];
   threadgroup T Ws[BN * BK_padded];
+  threadgroup T lut[16];
 
   if (batched) {
     adjust_matrix_offsets(
@@ -1175,7 +1255,7 @@ template <
        tid);
  }
   mxfp4_qmm_t_impl(
-      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
 template <
@@ -1212,6 +1292,7 @@ template <
   threadgroup T Xs[BM * BK_padded];
   threadgroup T Ws[BK * BN_padded];
+  threadgroup T lut[16];
 
   if (batched) {
     adjust_matrix_offsets(
@@ -1231,7 +1312,7 @@ template <
  }
 
   mxfp4_qmm_n_impl(
-      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
 template
@@ -1279,8 +1360,9 @@ template
      w_strides,
      s_strides,
      tid);
+  threadgroup float lut[16];
   mxfp4_qmv_fast_impl(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
 template
@@ -1328,8 +1410,9 @@ template
      w_strides,
      s_strides,
      tid);
+  threadgroup float lut[16];
   mxfp4_qmv_impl(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
 template
@@ -1377,8 +1460,9 @@ template
      w_strides,
      s_strides,
      tid);
+  threadgroup float lut[16];
   mxfp4_qvm_impl(
-      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
 template <
@@ -1420,6 +1504,7 @@ template <
   threadgroup T Xs[BM * BK_padded];
   threadgroup T Ws[BN * BK_padded];
+  threadgroup T lut[16];
 
   adjust_matrix_offsets(
      x,
@@ -1442,7 +1527,7 @@ template <
      s_strides,
      tid);
   mxfp4_qmm_t_impl(
-      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
 template <
@@ -1484,6 +1569,7 @@ template <
   threadgroup T Xs[BM * BK_padded];
   threadgroup T Ws[BK * BN_padded];
+  threadgroup T lut[16];
 
   adjust_matrix_offsets(
      x,
@@ -1506,7 +1592,7 @@ template <
      s_strides,
      tid);
   mxfp4_qmm_n_impl(
-      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
 template
@@ -1621,6 +1707,7 @@ template <
   constexpr int bytes_per_pack = get_bytes_per_pack();
   constexpr int BK_padded = (BK + 16 / sizeof(T));
   constexpr int BN_padded = (BN + 16 / sizeof(T));
+  threadgroup T lut[16];
 
   using mma_t = mlx::steel::BlockMMA<
      T,
@@ -1709,6 +1796,7 @@ template <
        scales + index * stride_s,
        transpose ? K : N,
        Ws,
+        lut,
        simd_group_id,
        simd_lane_id);
 
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index f792c8c11..e67d4922f 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -734,6 +734,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
         for L, K, D, E, I, transpose, mode in parameters:
             if mode == "mxfp4":
                 group_size = 32
+            else:
+                group_size = 64
             K, D = (K, D) if transpose else (D, K)
             ishape = (L, I)
             xshape = (L, 1, 1, K)
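
Note on the two optimizations in this patch, with a small host-side sketch.
MXFP4 stores one power-of-two scale per group of 32 values as a raw 8-bit
exponent byte, so the old kernels computed metal::pow(2, s - 127) once per
group. dequantize_scale instead writes the byte straight into the exponent
field of a bfloat16 (s << 7); s == 0 falls below the normal bfloat16 range and
is handled with the subnormal bit pattern 0x40, which equals 64 * 2^-133 =
2^-127. Separately, the 16-entry FP4 value table is copied into threadgroup
memory once per threadgroup (load_mxfp4_lut), so the inner loops of
qdot/qouter/dequantize index shared memory rather than the old constant-memory
table. The sketch below is not part of the patch: it is an illustrative plain
C++ recheck of the bit trick against ldexp for all 256 exponent bytes, with
hypothetical helper names and a bfloat16-as-top-16-float-bits conversion
assumed.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// A bfloat16 is the top 16 bits of an IEEE-754 binary32, so widening the bit
// pattern by 16 zero bits reinterprets it exactly as a float.
static float bf16_bits_to_float(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// Mirrors the kernel's dequantize_scale: exponent byte s -> 2^(s - 127).
static float dequantize_scale_host(uint8_t s) {
  uint16_t bits = (s == 0) ? 0x40 : static_cast<uint16_t>(s) << 7;
  return bf16_bits_to_float(bits);
}

int main() {
  for (int s = 0; s < 256; s++) {
    float expect = std::ldexp(1.0f, s - 127); // 2^(s - 127); +inf at s = 255
    if (dequantize_scale_host(static_cast<uint8_t>(s)) != expect) {
      std::printf("mismatch at s = %d\n", s);
      return 1;
    }
  }
  std::printf("all 256 scale bytes match 2^(s - 127)\n");
  return 0;
}

The payoff is that a transcendental-function call per group is replaced by one
shift and one select, and every FP4 dequantization becomes a single
threadgroup-memory load, which is what the "speedup" in the subject line
refers to.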