From 2f3b5c9c5c34e7dbfa92c0c92b7c6a9a8c13d8ff Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 24 Oct 2025 10:56:11 -0700
Subject: [PATCH] fix jit

---
 mlx/backend/metal/CMakeLists.txt             |   3 +-
 mlx/backend/metal/jit/includes.h             |   2 +-
 mlx/backend/metal/jit_kernels.cpp            |  51 ++---
 mlx/backend/metal/kernels/CMakeLists.txt     |   3 +-
 mlx/backend/metal/kernels/fp_quantized.h     | 185 +++++++++----------
 mlx/backend/metal/kernels/fp_quantized.metal | 112 ++++++-----
 mlx/backend/metal/quantized.cpp              |  11 +-
 7 files changed, 172 insertions(+), 195 deletions(-)

diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt
index 628a35ae9..0fd1834f6 100644
--- a/mlx/backend/metal/CMakeLists.txt
+++ b/mlx/backend/metal/CMakeLists.txt
@@ -81,7 +81,8 @@ if(MLX_METAL_JIT)
   make_jit_source(quantized_utils)
   make_jit_source(quantized kernels/quantized_utils.h)
-  make_jit_source(fp4_quantized kernels/quantized_utils.h)
+  make_jit_source(fp_quantized kernels/quantized_utils.h kernels/fp8.h
+                  kernels/fp4.h)
   make_jit_source(gemv_masked)
 else()
   target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h
index f3b57c7f9..c12e576c1 100644
--- a/mlx/backend/metal/jit/includes.h
+++ b/mlx/backend/metal/jit/includes.h
@@ -24,7 +24,7 @@ const char* hadamard();
 const char* logsumexp();
 const char* quantized_utils();
 const char* quantized();
-const char* fp4_quantized();
+const char* fp_quantized();
 const char* ternary();
 const char* scan();
 const char* scatter_axis();
diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp
index e70420cf8..de391abc9 100644
--- a/mlx/backend/metal/jit_kernels.cpp
+++ b/mlx/backend/metal/jit_kernels.cpp
@@ -829,7 +829,7 @@ MTL::ComputePipelineState* get_quantized_kernel(
         metal::utils(),
         metal::gemm(),
         metal::quantized_utils(),
-        (mode == "affine") ? metal::quantized() : metal::fp4_quantized(),
+        (mode == "affine") ? metal::quantized() : metal::fp_quantized(),
         template_def);
     return kernel_source;
   });
@@ -856,39 +856,22 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
     std::string kernel_source;
     concatenate(
         kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
-    if (mode == "affine") {
-      concatenate(
-          kernel_source,
-          metal::quantized(),
-          get_template_definition(
-              lib_name,
-              mode + "_gather_qmm_rhs",
-              get_type_string(x.dtype()),
-              group_size,
-              bits,
-              bm,
-              bn,
-              bk,
-              wm,
-              wn,
-              transpose));
-    } else {
-      concatenate(
-          kernel_source,
-          metal::fp4_quantized(),
-          get_template_definition(
-              lib_name,
-              mode + "_gather_qmm_rhs",
-              get_type_string(x.dtype()),
-              group_size,
-              "uint8_t",
-              bm,
-              bn,
-              bk,
-              wm,
-              wn,
-              transpose));
-    }
+    bool is_affine = mode == "affine";
+    concatenate(
+        kernel_source,
+        is_affine ? metal::quantized() : metal::fp_quantized(),
+        get_template_definition(
+            lib_name,
+            (is_affine ? "affine" : "fp") + std::string("_gather_qmm_rhs"),
+            get_type_string(x.dtype()),
+            group_size,
+            bits,
+            bm,
+            bn,
+            bk,
+            wm,
+            wn,
+            transpose));
     return kernel_source;
   });
   return d.get_kernel(kernel_name, lib, hash_name, func_consts);
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index cd743f0a8..69ac2a5e9 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -110,7 +110,8 @@ if(NOT MLX_METAL_JIT)
     reduction/reduce_col.h
     reduction/reduce_row.h)
   build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
-  build_kernel(fp_quantized fp_quantized.h quantized_utils.h ${STEEL_HEADERS})
+  build_kernel(fp_quantized fp4.h fp_quantized.h quantized_utils.h
+               ${STEEL_HEADERS})
   build_kernel(scan scan.h)
   build_kernel(softmax softmax.h)
   build_kernel(logsumexp logsumexp.h)
diff --git a/mlx/backend/metal/kernels/fp_quantized.h b/mlx/backend/metal/kernels/fp_quantized.h
index 2d2809bb7..38e4c3a73 100644
--- a/mlx/backend/metal/kernels/fp_quantized.h
+++ b/mlx/backend/metal/kernels/fp_quantized.h
@@ -3,6 +3,9 @@
 #include <metal_simdgroup>
 #include <metal_stdlib>
 
+#include "mlx/backend/metal/kernels/fp4.h"
+#include "mlx/backend/metal/kernels/fp8.h"
+
 constant bool align_M [[function_constant(200)]];
 constant bool align_N [[function_constant(201)]];
 constant bool align_K [[function_constant(202)]];
@@ -60,7 +63,7 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
 }
 
 template <typename T>
-void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
+void load_fp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
   if (simd_gid == 0 && simd_lid < 16) {
     lut[simd_lid] = static_cast<T>(FP4_LUT[simd_lid]);
   }
@@ -137,8 +140,7 @@ template <
     short dst_ld,
     short reduction_dim,
     short tgp_size,
-    short group_size,
-    typename S>
+    short group_size>
 struct QuantizedBlockLoader {
   static_assert(
       BCOLS <= group_size,
@@ -165,12 +167,12 @@ struct QuantizedBlockLoader {
 
   threadgroup T* dst;
   const device uint8_t* src;
-  const device S* scales;
+  const device uint8_t* scales;
   threadgroup T* lut;
 
   QuantizedBlockLoader(
       const device uint8_t* src_,
-      const device S* scales_,
+      const device uint8_t* scales_,
       const int src_ld_,
       threadgroup T* dst_,
       threadgroup T* lut_,
@@ -190,7 +192,7 @@ struct QuantizedBlockLoader {
             bj * bytes_per_pack),
         scales(scales_ + bi * src_ld / group_size),
         lut(lut_) {
-    load_mxfp4_lut(lut, simd_group_id, simd_lane_id);
+    load_fp4_lut(lut, simd_group_id, simd_lane_id);
   }
 
   void load_unsafe() const {
@@ -252,10 +254,10 @@ struct QuantizedBlockLoader {
   }
 };
 
-template <typename T, int group_size, typename S, int D>
-METAL_FUNC void mxfp4_qmv_quad_impl(
+template <typename T, int group_size, int bits, int D>
+METAL_FUNC void fp_qmv_quad_impl(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     constant int& in_vec_size,
@@ -277,7 +279,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
   thread U x_thread[values_per_thread];
   thread U result[results_per_quadgroup] = {0};
 
-  load_mxfp4_lut(lut, simd_gid, simd_lid);
+  load_fp4_lut(lut, simd_gid, simd_lid);
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size / pack_factor;
@@ -293,7 +295,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
   for (int row = 0; row < results_per_quadgroup; row++) {
     auto wl =
         (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
-    const device S* sl = scales + row * in_vec_size_g * quads_per_simd;
+    const device uint8_t* sl = scales + row * in_vec_size_g * quads_per_simd;
     U s = dequantize_scale(sl[0]);
 
     if (row * quads_per_simd + out_row < out_vec_size) {
@@ -309,10 +311,10 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
   }
 }
 
-template <typename T, int group_size, typename S>
-METAL_FUNC void mxfp4_qmv_fast_impl(
+template <typename T, int group_size, int bits>
+METAL_FUNC void fp_qmv_fast_impl(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& in_vec_size,
@@ -335,7 +337,7 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
   typedef float U;
   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};
-  load_mxfp4_lut(lut, simd_gid, simd_lid);
+  load_fp4_lut(lut, simd_gid, simd_lid);
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -372,10 +374,10 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
   }
 }
 
-template <typename T, int group_size, typename S>
-METAL_FUNC void mxfp4_qmv_impl(
+template <typename T, int group_size, int bits>
+METAL_FUNC void fp_qmv_impl(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
    const device T* x,
     device T* y,
     const constant int& in_vec_size,
@@ -400,7 +402,7 @@ METAL_FUNC void mxfp4_qmv_impl(
   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};
 
-  load_mxfp4_lut(lut, simd_gid, simd_lid);
+  load_fp4_lut(lut, simd_gid, simd_lid);
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -430,7 +432,7 @@ METAL_FUNC void mxfp4_qmv_impl(
       auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
       const device auto* sl = scales + row * in_vec_size_g;
 
-      S s = sl[0];
+      uint8_t s = sl[0];
       result[row] += qdot(wl, x_thread, s, lut);
     }
 
@@ -511,10 +513,10 @@ METAL_FUNC void mxfp4_qmv_impl(
   }
 }
 
-template <typename T, int group_size, typename S>
-METAL_FUNC void mxfp4_qvm_impl(
+template <typename T, int group_size, int bits>
+METAL_FUNC void fp_qvm_impl(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const int in_vec_size,
@@ -543,7 +545,7 @@ METAL_FUNC void mxfp4_qvm_impl(
   thread U scale = 0;
   thread U x_local = 0;
 
-  load_mxfp4_lut(lut, simd_gid, simd_lid);
+  load_fp4_lut(lut, simd_gid, simd_lid);
 
   // Adjust positions
   const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
@@ -615,14 +617,14 @@ METAL_FUNC void mxfp4_qvm_impl(
 template <
     typename T,
     const int group_size,
-    typename S,
+    const int bits,
     const bool aligned_N,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
-METAL_FUNC void mxfp4_qmm_t_impl(
+METAL_FUNC void fp_qmm_t_impl(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     threadgroup T* Xs,
@@ -659,8 +661,7 @@ METAL_FUNC void mxfp4_qmm_t_impl(
       BK_padded,
       1,
       WM * WN * SIMD_SIZE,
-      group_size,
-      S>;
+      group_size>;
 
   // Set the block
   const int K_w = K * bytes_per_pack / pack_factor;
@@ -741,13 +742,13 @@ METAL_FUNC void mxfp4_qmm_t_impl(
 template <
     typename T,
     const int group_size,
-    typename S,
+    const int bits,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
-METAL_FUNC void mxfp4_qmm_n_impl(
+METAL_FUNC void fp_qmm_n_impl(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     threadgroup T* Xs,
@@ -785,8 +786,7 @@ METAL_FUNC void mxfp4_qmm_n_impl(
       BN_padded,
       0,
       WM * WN * SIMD_SIZE,
-      group_size,
-      S>;
+      group_size>;
 
   auto wl = (const device uint8_t*)w;
 
@@ -873,11 +873,11 @@ METAL_FUNC void mxfp4_qmm_n_impl(
   }
 }
 
-template <typename T, typename S>
+template <typename T>
 METAL_FUNC void adjust_matrix_offsets(
     const device T*& x,
     const device uint32_t*& w,
-    const device S*& scales,
+    const device uint8_t*& scales,
     device T*& y,
     int output_stride,
     const constant int& x_batch_ndims,
@@ -908,11 +908,11 @@ METAL_FUNC void adjust_matrix_offsets(
   y += tid.z * output_stride;
 }
 
-template <typename T, typename S>
+template <typename T>
 METAL_FUNC void adjust_matrix_offsets(
     const device T*& x,
     const device uint32_t*& w,
-    const device S*& scales,
+    const device uint8_t*& scales,
     const device uint32_t* lhs_indices,
     const device uint32_t* rhs_indices,
     device T*& y,
@@ -958,10 +958,10 @@ METAL_FUNC void adjust_matrix_offsets(
   y += tid.z * output_stride;
 }
 
-template <typename T, int group_size, typename S, int D, bool batched>
-[[kernel]] void mxfp4_qmv_quad(
+template <typename T, int group_size, int bits, int D, bool batched>
+[[kernel]] void fp_qmv_quad(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& in_vec_size,
@@ -996,7 +996,7 @@ template <
         tid);
   }
   threadgroup float lut[16];
-  mxfp4_qmv_quad_impl<T, group_size, S, D>(
+  fp_qmv_quad_impl<T, group_size, bits, D>(
       w,
       scales,
       x,
@@ -1011,10 +1011,10 @@ template <
 
-template <typename T, int group_size, typename S, bool batched>
-[[kernel]] void mxfp4_qmv_fast(
+template <typename T, int group_size, int bits, bool batched>
+[[kernel]] void fp_qmv_fast(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& in_vec_size,
     const constant int& out_vec_size,
@@ -1047,14 +1047,14 @@ template <
         tid);
   }
   threadgroup float lut[16];
-  mxfp4_qmv_fast_impl<T, group_size, S>(
+  fp_qmv_fast_impl<T, group_size, bits>(
       w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, int group_size, typename S, bool batched>
-[[kernel]] void mxfp4_qmv(
+template <typename T, int group_size, int bits, bool batched>
+[[kernel]] void fp_qmv(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& in_vec_size,
@@ -1087,14 +1087,14 @@ template <
         tid);
   }
   threadgroup float lut[16];
-  mxfp4_qmv_impl<T, group_size, S>(
+  fp_qmv_impl<T, group_size, bits>(
       w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, int group_size, typename S, bool batched>
-[[kernel]] void mxfp4_qvm(
+template <typename T, int group_size, int bits, bool batched>
+[[kernel]] void fp_qvm(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& in_vec_size,
@@ -1127,14 +1127,14 @@ template <
         tid);
   }
   threadgroup float lut[16];
-  mxfp4_qvm_impl<T, group_size, S>(
+  fp_qvm_impl<T, group_size, bits>(
      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, int group_size, typename S, int split_k = 32>
-[[kernel]] void mxfp4_qvm_split_k(
+template <typename T, int group_size, int bits, int split_k = 32>
+[[kernel]] void fp_qvm_split_k(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& in_vec_size,
@@ -1171,7 +1171,7 @@ template <
       tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
   threadgroup float lut[16];
-  mxfp4_qvm_impl<T, group_size, S>(
+  fp_qvm_impl<T, group_size, bits>(
       w,
       scales,
       x,
@@ -1187,15 +1187,15 @@ template <
 template <
     typename T,
     const int group_size,
-    typename S,
+    const int bits,
     const bool aligned_N,
     const bool batched,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
-[[kernel]] void mxfp4_qmm_t(
+[[kernel]] void fp_qmm_t(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& K,
@@ -1236,21 +1236,21 @@ template <
         s_strides,
         tid);
   }
-  mxfp4_qmm_t_impl<T, group_size, S, aligned_N>(
+  fp_qmm_t_impl<T, group_size, bits, aligned_N>(
       w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
 template <
     typename T,
     const int group_size,
-    typename S,
+    const int bits,
     const bool batched,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
-[[kernel]] void mxfp4_qmm_n(
+[[kernel]] void fp_qmm_n(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& K,
@@ -1293,14 +1293,14 @@ template <
         tid);
   }
 
-  mxfp4_qmm_n_impl<T, group_size, S>(
+  fp_qmm_n_impl<T, group_size, bits>(
      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, int group_size, typename S>
-[[kernel]] void mxfp4_gather_qmv_fast(
+template <typename T, int group_size, int bits>
+[[kernel]] void fp_gather_qmv_fast(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     const device uint32_t* lhs_indices,
     const device uint32_t* rhs_indices,
@@ -1343,14 +1343,14 @@ template <
       s_strides,
       tid);
   threadgroup float lut[16];
-  mxfp4_qmv_fast_impl<T, group_size, S>(
+  fp_qmv_fast_impl<T, group_size, bits>(
      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, int group_size, typename S>
-[[kernel]] void mxfp4_gather_qmv(
+template <typename T, int group_size, int bits>
+[[kernel]] void fp_gather_qmv(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     const device uint32_t* lhs_indices,
     const device uint32_t* rhs_indices,
@@ -1393,14 +1393,14 @@ template <
       s_strides,
       tid);
   threadgroup float lut[16];
-  mxfp4_qmv_impl<T, group_size, S>(
+  fp_qmv_impl<T, group_size, bits>(
      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, int group_size, typename S>
-[[kernel]] void mxfp4_gather_qvm(
+template <typename T, int group_size, int bits>
+[[kernel]] void fp_gather_qvm(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     const device uint32_t* lhs_indices,
     const device uint32_t* rhs_indices,
@@ -1443,21 +1443,21 @@ template <
       s_strides,
       tid);
   threadgroup float lut[16];
-  mxfp4_qvm_impl<T, group_size, S>(
+  fp_qvm_impl<T, group_size, bits>(
      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
 template <
     typename T,
     const int group_size,
-    typename S,
+    const int bits,
     const bool aligned_N,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
-[[kernel]] void mxfp4_gather_qmm_t(
+[[kernel]] void fp_gather_qmm_t(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     const device uint32_t* lhs_indices,
     const device uint32_t* rhs_indices,
@@ -1508,20 +1508,20 @@ template <
       w_strides,
       s_strides,
       tid);
-  mxfp4_qmm_t_impl<T, group_size, S, aligned_N>(
+  fp_qmm_t_impl<T, group_size, bits, aligned_N>(
       w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
 template <
     typename T,
     const int group_size,
-    typename S,
+    const int bits,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
-[[kernel]] void mxfp4_gather_qmm_n(
+[[kernel]] void fp_gather_qmm_n(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     const device uint32_t* lhs_indices,
     const device uint32_t* rhs_indices,
@@ -1573,24 +1573,24 @@ template <
       w_strides,
       s_strides,
       tid);
-  mxfp4_qmm_n_impl<T, group_size, S>(
+  fp_qmm_n_impl<T, group_size, bits>(
       w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
 template <
     typename T,
     int group_size,
-    typename S,
+    int bits,
     int BM,
     int BN,
     int BK,
     int WM,
     int WN,
     bool transpose>
-[[kernel]] void mxfp4_gather_qmm_rhs(
+[[kernel]] void fp_gather_qmm_rhs(
     const device T* x,
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device uint32_t* indices,
     device T* y,
     const constant int& M,
@@ -1626,8 +1626,7 @@ template <
       transpose ? BK_padded : BN_padded,
       transpose,
       WM * WN * SIMD_SIZE,
-      group_size,
-      S>;
+      group_size>;
 
   threadgroup T Xs[BM * BK_padded];
   threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
@@ -1775,7 +1774,7 @@ template <
 template <int bits>
 struct Quantize {
   uint8_t operator()(float x) {
-    if constexpr (bits == 8) {
+    if (bits == 8) {
       return fp8_e4m3(x).bits;
     } else {
       return fp4_e2m1(x).bits;
@@ -1786,7 +1785,7 @@ struct Dequantize {
   float operator()(uint8_t x) {
-    if constexpr (bits == 8) {
+    if (bits == 8) {
       return float(*(thread fp8_e4m3*)(&x));
     } else {
       return float(*(thread fp4_e2m1*)(&x));
diff --git a/mlx/backend/metal/kernels/fp_quantized.metal b/mlx/backend/metal/kernels/fp_quantized.metal
index bef26b786..091174d91 100644
--- a/mlx/backend/metal/kernels/fp_quantized.metal
+++ b/mlx/backend/metal/kernels/fp_quantized.metal
@@ -4,63 +4,61 @@
 #include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
 #include "mlx/backend/metal/kernels/quantized_utils.h"
-#include "mlx/backend/metal/kernels/fp8.h"
-#include "mlx/backend/metal/kernels/fp4.h"
 #include "mlx/backend/metal/kernels/fp_quantized.h"
 
-#define instantiate_quantized(name, type)       \
+#define instantiate_quantized(mode, name, type) \
   instantiate_kernel(                           \
-      #name "_" #type "_gs_32_b_4",             \
-      name,                                     \
+      #mode "_" #name "_" #type "_gs_32_b_4",   \
+      fp_ ## name,                              \
      type,                                     \
      32,                                       \
-      uint8_t)
+      4)
 
-#define instantiate_quantized_batched(name, type, batched)       \
+#define instantiate_quantized_batched(mode, name, type, batched) \
   instantiate_kernel(                                            \
-      #name "_" #type "_gs_32_b_4_batch_" #batched,              \
-      name,                                                      \
+      #mode "_" #name "_" #type "_gs_32_b_4_batch_" #batched,    \
+      fp_ ## name,                                               \
       type,                                                      \
       32,                                                        \
-      uint8_t,                                                   \
+      4,                                                         \
       batched)
 
-#define instantiate_quantized_aligned(name, type, aligned)       \
+#define instantiate_quantized_aligned(mode, name, type, aligned) \
   instantiate_kernel(                                            \
-      #name "_" #type "_gs_32_b_4_alN_" #aligned,                \
-      name,                                                      \
+      #mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned,      \
+      fp_ ## name,                                               \
       type,                                                      \
       32,                                                        \
-      uint8_t,                                                   \
+      4,                                                         \
       aligned)
 
-#define instantiate_quantized_aligned_batched(name, type, aligned, batched)       \
+#define instantiate_quantized_aligned_batched(mode, name, type, aligned, batched) \
   instantiate_kernel(                                                             \
-      #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched,              \
-      name,                                                                       \
+      #mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched,    \
+      fp_ ## name,                                                                \
       type,                                                                       \
       32,                                                                         \
-      uint8_t,                                                                    \
+      4,                                                                          \
       aligned,                                                                    \
       batched)
 
-#define instantiate_quantized_quad(name, type, D, batched)         \
+#define instantiate_quantized_quad(mode, name, type, D, batched)   \
   instantiate_kernel(                                              \
-      #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched,       \
-      name,                                                        \
+      #mode "_" #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
+      fp_ ## name,                                                 \
       type,                                                        \
       32,                                                          \
-      uint8_t,                                                     \
+      4,                                                           \
       D,                                                           \
       batched)
 
-#define instantiate_quantized_split_k(name, type, split_k)       \
+#define instantiate_quantized_split_k(mode, name, type, split_k) \
   instantiate_kernel(                                            \
"_gs_32_b_4_spk_" #split_k, \ - name, \ + #mode "_" #name "_" #type "_gs_32_b_4_spk_" #split_k, \ + fp_ ## name, \ type, \ 32, \ - uint8_t, \ + 4, \ split_k) #define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \ @@ -69,7 +67,7 @@ func, \ type, \ 32, \ - uint8_t, \ + 4, \ bm, \ bn, \ bk, \ @@ -77,62 +75,62 @@ wn, \ transpose) -#define instantiate_quantized_batched_wrap(name, type) \ - instantiate_quantized_batched(name, type, 1) \ - instantiate_quantized_batched(name, type, 0) +#define instantiate_quantized_batched_wrap(mode, name, type) \ + instantiate_quantized_batched(mode, name, type, 1) \ + instantiate_quantized_batched(mode, name, type, 0) #define instantiate_quantized_all_batched(type) \ - instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \ - instantiate_quantized_batched_wrap(mxfp4_qmv, type) \ - instantiate_quantized_batched_wrap(mxfp4_qvm, type) \ - instantiate_quantized_batched_wrap(mxfp4_qmm_n, type) + instantiate_quantized_batched_wrap(mxfp4, qmv_fast, type) \ + instantiate_quantized_batched_wrap(mxfp4, qmv, type) \ + instantiate_quantized_batched_wrap(mxfp4, qvm, type) \ + instantiate_quantized_batched_wrap(mxfp4, qmm_n, type) #define instantiate_quantized_all_single(type) \ - instantiate_quantized(mxfp4_gather_qmv_fast, type) \ - instantiate_quantized(mxfp4_gather_qmv, type) \ - instantiate_quantized(mxfp4_gather_qvm, type) \ - instantiate_quantized(mxfp4_gather_qmm_n, type) + instantiate_quantized(mxfp4, gather_qmv_fast, type) \ + instantiate_quantized(mxfp4, gather_qmv, type) \ + instantiate_quantized(mxfp4, gather_qvm, type) \ + instantiate_quantized(mxfp4, gather_qmm_n, type) #define instantiate_quantized_all_aligned(type) \ - instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \ - instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \ - instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \ - instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \ - instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \ - instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0) + instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, true) \ + instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, false) \ + instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 1) \ + instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 0) \ + instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 1) \ + instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 0) #define instantiate_quantized_all_quad(type) \ - instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \ - instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \ - instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \ - instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0) + instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 1) \ + instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 0) \ + instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 1) \ + instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 0) #define instantiate_quantized_all_splitk(type) \ - instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \ - instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32) + instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 8) \ + instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 32) #define instantiate_quantized_all_rhs(type) \ - instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \ - 
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false) + instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \ + instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false) #define instantiate_quantize_dequantize(type, mode, group_size, bits) \ instantiate_kernel( \ - mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \ + #mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \ fp_quantize, \ type, \ group_size, \ bits) \ instantiate_kernel( \ - mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \ + #mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \ fp_dequantize, \ type, \ group_size, \ bits) #define instantiate_quantize_dequantize_modes(type) \ - instantiate_quantize_dequantize(type, "mxfp4", 32, 4) \ - instantiate_quantize_dequantize(type, "nvfp4", 16, 4) \ - instantiate_quantize_dequantize(type, "mxfp8", 32, 8) + instantiate_quantize_dequantize(type, mxfp4, 32, 4) \ + instantiate_quantize_dequantize(type, nvfp4, 16, 4) \ + instantiate_quantize_dequantize(type, mxfp8, 32, 8) #define instantiate_quantized_types(type) \ instantiate_quantized_all_batched(type) \ diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 3a5b9cec3..e03e5dca2 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -27,14 +27,9 @@ auto get_quantized_kernel_wrapped( int bits, Args... args) { std::string template_def; - auto fname = mode + "_" + func; - if (mode == "affine") { - template_def = get_template_definition( - name, fname, type, group_size, bits, std::forward(args)...); - } else { - template_def = get_template_definition( - name, fname, type, group_size, "uint8_t", std::forward(args)...); - } + std::string fname = ((mode == "affine") ? "affine_" : "fp_") + func; + template_def = get_template_definition( + name, fname, type, group_size, bits, std::forward(args)...); return get_quantized_kernel(d, name, template_def, mode); }