From 39b04ce6382e4bebff7583c3ef64c9182fb701b5 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 31 Oct 2025 11:49:59 -0700
Subject: [PATCH] use faster dequant for fp4 qmv (#2720)

---
 mlx/backend/metal/kernels/fp4.h          |   5 +-
 mlx/backend/metal/kernels/fp8.h          |  24 ++-
 mlx/backend/metal/kernels/fp_quantized.h | 191 ++++++++---------------
 3 files changed, 76 insertions(+), 144 deletions(-)

diff --git a/mlx/backend/metal/kernels/fp4.h b/mlx/backend/metal/kernels/fp4.h
index 40742cc31..e701adc5d 100644
--- a/mlx/backend/metal/kernels/fp4.h
+++ b/mlx/backend/metal/kernels/fp4.h
@@ -49,7 +49,10 @@ struct fp4_e2m1 {
   }
 
   operator float() {
-    return FP4_LUT[bits];
+    half converted = as_type<half>(ushort((bits & 7) << 9));
+    converted *= 16384.0;
+    converted = bits & 8 ? -converted : converted;
+    return converted;
   }
 
   uint8_t bits;
diff --git a/mlx/backend/metal/kernels/fp8.h b/mlx/backend/metal/kernels/fp8.h
index 4b1836a39..34816b42b 100644
--- a/mlx/backend/metal/kernels/fp8.h
+++ b/mlx/backend/metal/kernels/fp8.h
@@ -1,12 +1,5 @@
 #pragma once
 
-inline float fp32_from_bits(uint32_t bits) {
-  return *(reinterpret_cast<thread float*>(&bits));
-}
-inline uint32_t fp32_to_bits(float x) {
-  return *(reinterpret_cast<thread uint32_t*>(&x));
-}
-
 struct fp8_e4m3 {
   template <typename T>
   fp8_e4m3(T f) {
@@ -14,7 +7,7 @@
     // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148
     uint32_t fp8_max = 543 << 21;
     uint32_t denorm_mask = 141 << 23;
-    uint32_t f_bits = fp32_to_bits(static_cast<float>(f));
+    uint32_t f_bits = as_type<uint32_t>(static_cast<float>(f));
     uint32_t sign = f_bits & 0x80000000;
     f_bits ^= sign;
     if (f_bits >= fp8_max) {
@@ -22,8 +15,8 @@
       bits = 0x7E;
     } else {
       if (f_bits < (121 << 23)) {
-        f_bits =
-            fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
+        f_bits = as_type<uint32_t>(
+            as_type<float>(f_bits) + as_type<float>(denorm_mask));
         bits = static_cast<uint8_t>(f_bits - denorm_mask);
       } else {
         // resulting mantissa is odd
@@ -53,7 +46,7 @@
         ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
           inf_nan_mask) &
          ~zero_mask);
-    return fp32_from_bits(result);
+    return as_type<float>(result);
   }
 
   uint8_t bits;
@@ -77,11 +70,12 @@ struct fp8_e8m0 {
     bits = static_cast<uint8_t>(n + 127);
   }
 
+  operator bfloat16_t() {
+    uint16_t out = (bits == 0 ? 0x40 : (static_cast<uint16_t>(bits) << 7));
+    return as_type<bfloat16_t>(out);
+  }
   operator float() {
-    if (bits == 0xFF) {
-      return metal::numeric_limits<float>::quiet_NaN();
-    }
-    return metal::ldexp(1.0f, static_cast<int>(bits) - 127);
+    return static_cast<float>(this->operator bfloat16_t());
   }
 
   uint8_t bits;
diff --git a/mlx/backend/metal/kernels/fp_quantized.h b/mlx/backend/metal/kernels/fp_quantized.h
index 38e4c3a73..cae1bbd9e 100644
--- a/mlx/backend/metal/kernels/fp_quantized.h
+++ b/mlx/backend/metal/kernels/fp_quantized.h
@@ -29,15 +29,31 @@ inline constexpr short get_bytes_per_pack() {
 }
 
 template <typename T>
 static inline T dequantize_scale(uint8_t s) {
-  using FOrI = union {
-    bfloat16_t f;
-    uint16_t i;
-  };
-  FOrI out;
-  out.i = (s == 0 ?
0x40 : (static_cast(s) << 7)); - return static_cast(out.f); + return T(*(thread fp8_e8m0*)(&s)); } +template +struct Quantize { + uint8_t operator()(float x) { + if (bits == 8) { + return fp8_e4m3(x).bits; + } else { + return fp4_e2m1(x).bits; + } + } +}; + +template +struct Dequantize { + float operator()(uint8_t x) { + if (bits == 8) { + return float(*(thread fp8_e4m3*)(&x)); + } else { + return float(*(thread fp4_e2m1*)(&x)); + } + } +}; + template inline void load_vector(const device T* x, thread U* x_thread) { for (int i = 0; i < values_per_thread; i += 4) { @@ -62,62 +78,41 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) { } } -template -void load_fp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) { - if (simd_gid == 0 && simd_lid < 16) { - lut[simd_lid] = static_cast(FP4_LUT[simd_lid]); - } - threadgroup_barrier(mem_flags::mem_threadgroup); -} - template -inline U qdot( - const device uint8_t* w, - const thread U* x_thread, - U scale, - const threadgroup U* lut) { +inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) { U accum = 0; const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (values_per_thread / 4); i++) { accum += - (x_thread[4 * i] * lut[ws[i] & 0xf] + - x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] + - x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] + - x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]); + (x_thread[4 * i] * Dequantize<4>{}(ws[i]) + + x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) + + x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) + + x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12)); } return scale * accum; } template -inline U qdot_safe( - const device uint8_t* w, - const thread U* x_thread, - U scale, - const threadgroup U* lut, - int N) { +inline U +qdot_safe(const device uint8_t* w, const thread U* x_thread, U scale, int N) { U accum = 0; const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (N / 4); i++) { accum += - (x_thread[4 * i] * lut[ws[i] & 0xf] + - x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] + - x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] + - x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]); + (x_thread[4 * i] * Dequantize<4>{}(ws[i]) + + x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) + + x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) + + x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12)); } return scale * accum; } template -inline void qouter( - const thread uint8_t* w, - U x, - U scale, - thread U* result, - const threadgroup U* lut) { +inline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) { for (int i = 0; i < (values_per_thread / 2); i++) { - result[2 * i] += x * scale * lut[w[i] & 0xf]; - result[2 * i + 1] += x * scale * lut[(w[i] >> 4) & 0xf]; + result[2 * i] += x * scale * Dequantize<4>{}(w[i]); + result[2 * i + 1] += x * scale * Dequantize<4>{}(w[i] >> 4); } } @@ -192,7 +187,10 @@ struct QuantizedBlockLoader { bj * bytes_per_pack), scales(scales_ + bi * src_ld / group_size), lut(lut_) { - load_fp4_lut(lut, simd_group_id, simd_lane_id); + if (simd_group_id == 0 && simd_lane_id < 16) { + lut[simd_lane_id] = static_cast(FP4_LUT[simd_lane_id]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); } void load_unsafe() const { @@ -264,10 +262,7 @@ METAL_FUNC void fp_qmv_quad_impl( const constant int& out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint quad_gid [[quadgroup_index_in_threadgroup]], - uint quad_lid [[thread_index_in_quadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], 
- uint simd_lid [[thread_index_in_simdgroup]], - threadgroup float* lut) { + uint quad_lid [[thread_index_in_quadgroup]]) { constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; constexpr int pack_factor = 8; constexpr int values_per_thread = D / QUAD_SIZE; @@ -279,7 +274,6 @@ METAL_FUNC void fp_qmv_quad_impl( thread U x_thread[values_per_thread]; thread U result[results_per_quadgroup] = {0}; - load_fp4_lut(lut, simd_gid, simd_lid); // Adjust positions const int in_vec_size_w = in_vec_size / pack_factor; @@ -299,7 +293,7 @@ METAL_FUNC void fp_qmv_quad_impl( U s = dequantize_scale(sl[0]); if (row * quads_per_simd + out_row < out_vec_size) { - result[row] += qdot(wl, x_thread, s, lut); + result[row] += qdot(wl, x_thread, s); } } @@ -321,8 +315,7 @@ METAL_FUNC void fp_qmv_fast_impl( const constant int& out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]], - threadgroup float* lut) { + uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int packs_per_thread = 2; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; @@ -337,7 +330,6 @@ METAL_FUNC void fp_qmv_fast_impl( typedef float U; thread U x_thread[values_per_thread]; thread U result[results_per_simdgroup] = {0}; - load_fp4_lut(lut, simd_gid, simd_lid); // Adjust positions const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; @@ -358,7 +350,7 @@ METAL_FUNC void fp_qmv_fast_impl( const device auto* sl = scales + row * in_vec_size_g; U s = dequantize_scale(sl[0]); - result[row] += qdot(wl, x_thread, s, lut); + result[row] += qdot(wl, x_thread, s); } ws += block_size * bytes_per_pack / pack_factor; @@ -384,8 +376,7 @@ METAL_FUNC void fp_qmv_impl( const constant int& out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]], - threadgroup float* lut) { + uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int packs_per_thread = 1; @@ -402,7 +393,6 @@ METAL_FUNC void fp_qmv_impl( thread U x_thread[values_per_thread]; thread U result[results_per_simdgroup] = {0}; - load_fp4_lut(lut, simd_gid, simd_lid); // Adjust positions const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; @@ -433,7 +423,7 @@ METAL_FUNC void fp_qmv_impl( const device auto* sl = scales + row * in_vec_size_g; uint8_t s = sl[0]; - result[row] += qdot(wl, x_thread, s, lut); + result[row] += qdot(wl, x_thread, s); } ws += block_size * bytes_per_pack / pack_factor; @@ -452,7 +442,7 @@ METAL_FUNC void fp_qmv_impl( const device auto* sl = scales + row * in_vec_size_g; U s = dequantize_scale(sl[0]); - result[row] += qdot(wl, x_thread, s, lut); + result[row] += qdot(wl, x_thread, s); } } @@ -481,7 +471,7 @@ METAL_FUNC void fp_qmv_impl( const device auto* sl = scales + row * in_vec_size_g; U s = dequantize_scale(sl[0]); - result[row] += qdot(wl, x_thread, s, lut); + result[row] += qdot(wl, x_thread, s); } ws += block_size * bytes_per_pack / pack_factor; @@ -501,7 +491,7 @@ METAL_FUNC void fp_qmv_impl( U s = dequantize_scale(sl[0]); result[row] += - qdot_safe(wl, x_thread, s, lut, remaining); + qdot_safe(wl, x_thread, s, remaining); } } for (int row = 0; row < results_per_simdgroup; row++) { @@ -523,8 +513,7 @@ METAL_FUNC void fp_qvm_impl( const int out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], - 
uint simd_lid [[thread_index_in_simdgroup]], - threadgroup float* lut) { + uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int num_simdgroups = 2; constexpr int pack_factor = get_pack_factor<32>(); constexpr int bytes_per_pack = get_bytes_per_pack(); @@ -545,8 +534,6 @@ METAL_FUNC void fp_qvm_impl( thread U scale = 0; thread U x_local = 0; - load_fp4_lut(lut, simd_gid, simd_lid); - // Adjust positions const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; const int out_vec_size_g = out_vec_size / group_size; @@ -568,7 +555,7 @@ METAL_FUNC void fp_qvm_impl( scale = dequantize_scale(*scales); w_local = *((device vec_w*)ws); qouter( - (thread uint8_t*)&w_local, x_local, scale, result, lut); + (thread uint8_t*)&w_local, x_local, scale, result); x += block_size; scales += block_size * out_vec_size_g; @@ -581,7 +568,7 @@ METAL_FUNC void fp_qvm_impl( w_local = *((device vec_w*)ws); qouter( - (thread uint8_t*)&w_local, x_local, scale, result, lut); + (thread uint8_t*)&w_local, x_local, scale, result); x += block_size; scales += block_size * out_vec_size_g; @@ -596,7 +583,7 @@ METAL_FUNC void fp_qvm_impl( scale = 0; } qouter( - (thread uint8_t*)&w_local, x_local, scale, result, lut); + (thread uint8_t*)&w_local, x_local, scale, result); } // Accumulate in the simdgroup @@ -975,9 +962,7 @@ template const constant int64_t* s_strides, uint3 tid [[threadgroup_position_in_grid]], uint quad_gid [[quadgroup_index_in_threadgroup]], - uint quad_lid [[thread_index_in_quadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { + uint quad_lid [[thread_index_in_quadgroup]]) { if (batched) { int M = x_shape[x_batch_ndims]; adjust_matrix_offsets( @@ -995,20 +980,8 @@ template s_strides, tid); } - threadgroup float lut[16]; fp_qmv_quad_impl( - w, - scales, - x, - y, - in_vec_size, - out_vec_size, - tid, - quad_gid, - quad_lid, - simd_gid, - simd_lid, - lut); + w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid); } template @@ -1046,9 +1019,8 @@ template s_strides, tid); } - threadgroup float lut[16]; fp_qmv_fast_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template @@ -1086,9 +1058,8 @@ template s_strides, tid); } - threadgroup float lut[16]; fp_qmv_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template @@ -1126,9 +1097,8 @@ template s_strides, tid); } - threadgroup float lut[16]; fp_qvm_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template @@ -1170,18 +1140,8 @@ template int in_vec_size_adj = tid.z % split_k == split_k - 1 ? 
final_block_size : in_vec_size; - threadgroup float lut[16]; fp_qvm_impl( - w, - scales, - x, - y, - in_vec_size_adj, - out_vec_size, - tid, - simd_gid, - simd_lid, - lut); + w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid); } template < @@ -1342,9 +1302,8 @@ template w_strides, s_strides, tid); - threadgroup float lut[16]; fp_qmv_fast_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template @@ -1392,9 +1351,8 @@ template w_strides, s_strides, tid); - threadgroup float lut[16]; fp_qmv_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template @@ -1442,9 +1400,8 @@ template w_strides, s_strides, tid); - threadgroup float lut[16]; fp_qvm_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); } template < @@ -1771,28 +1728,6 @@ template < } } -template -struct Quantize { - uint8_t operator()(float x) { - if (bits == 8) { - return fp8_e4m3(x).bits; - } else { - return fp4_e2m1(x).bits; - } - } -}; - -template -struct Dequantize { - float operator()(uint8_t x) { - if (bits == 8) { - return float(*(thread fp8_e4m3*)(&x)); - } else { - return float(*(thread fp4_e2m1*)(&x)); - } - } -}; - template [[kernel]] void fp_quantize( const device T* w [[buffer(0)]],