From 827003d5684081bd150175bee5fd2630d6fae8a5 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 28 Aug 2025 18:26:25 -0700 Subject: [PATCH] fix METAL quantization in JIT (#2553) --- .circleci/config.yml | 4 +- mlx/backend/metal/CMakeLists.txt | 5 +- mlx/backend/metal/jit/includes.h | 2 + mlx/backend/metal/jit_kernels.cpp | 67 +++++++---- mlx/backend/metal/kernels.h | 4 +- mlx/backend/metal/kernels/fp4_quantized.h | 26 ++-- mlx/backend/metal/kernels/fp4_quantized.metal | 20 ++-- mlx/backend/metal/nojit_kernels.cpp | 2 + mlx/backend/metal/quantized.cpp | 113 +++++++++++------- 9 files changed, 154 insertions(+), 89 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 85c3b7967..720f99fcc 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -199,7 +199,7 @@ jobs: name: Run Python tests with JIT command: | CMAKE_ARGS="-DMLX_METAL_JIT=ON" \ - uv pip install -e . + uv pip install -e . -v LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \ METAL_DEBUG_ERROR_MODE=0 \ uv run --no-project python -m xmlrunner discover \ @@ -298,7 +298,7 @@ jobs: rm ~/miniconda3/miniconda.sh source ~/miniconda3/bin/activate conda init --all - conda create -n env python=<< parameters.python_version >> + conda create -n env python=<< parameters.python_version >> -y conda activate env pip install --upgrade cmake pip install nanobind==2.4.0 diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index ccdd83202..8f842259c 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -77,7 +77,10 @@ if(MLX_METAL_JIT) make_jit_source(steel/conv/kernels/steel_conv) make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h kernels/steel/conv/loaders/loader_general.h) - make_jit_source(quantized) + + make_jit_source(quantized_utils) + make_jit_source(quantized kernels/quantized_utils.h) + make_jit_source(fp4_quantized kernels/quantized_utils.h) make_jit_source(gemv_masked) else() target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp) diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index b380a8374..068a3ff08 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -21,7 +21,9 @@ const char* fft(); const char* gather_axis(); const char* hadamard(); const char* logsumexp(); +const char* quantized_utils(); const char* quantized(); +const char* fp4_quantized(); const char* ternary(); const char* scan(); const char* scatter_axis(); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 6ae72e0aa..cc8addf9f 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -804,13 +804,19 @@ MTL::ComputePipelineState* get_fft_kernel( MTL::ComputePipelineState* get_quantized_kernel( metal::Device& d, const std::string& kernel_name, - const std::string& template_def) { + const std::string& template_def, + const std::string& mode) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { - std::ostringstream kernel_source; - kernel_source << metal::utils() << metal::gemm() << metal::quantized() - << template_def; - return kernel_source.str(); + std::string kernel_source; + concatenate( + kernel_source, + metal::utils(), + metal::gemm(), + metal::quantized_utils(), + (mode == "affine") ? metal::quantized() : metal::fp4_quantized(), + template_def); + return kernel_source; }); return d.get_kernel(kernel_name, lib); } @@ -823,6 +829,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel( const array& x, int group_size, int bits, + const std::string& mode, int bm, int bn, int bk, @@ -833,22 +840,40 @@ MTL::ComputePipelineState* get_gather_qmm_kernel( auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; concatenate( - kernel_source, - metal::utils(), - metal::gemm(), - metal::quantized(), - get_template_definition( - lib_name, - "gather_qmm_rhs", - get_type_string(x.dtype()), - group_size, - bits, - bm, - bn, - bk, - wm, - wn, - transpose)); + kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm()); + if (mode == "affine") { + concatenate( + kernel_source, + metal::quantized(), + get_template_definition( + lib_name, + mode + "_gather_qmm_rhs", + get_type_string(x.dtype()), + group_size, + bits, + bm, + bn, + bk, + wm, + wn, + transpose)); + } else { + concatenate( + kernel_source, + metal::fp4_quantized(), + get_template_definition( + lib_name, + mode + "_gather_qmm_rhs", + get_type_string(x.dtype()), + group_size, + "uint8_t", + bm, + bn, + bk, + wm, + wn, + transpose)); + } return kernel_source; }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index ca29ca52e..0656f3c0f 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -238,7 +238,8 @@ MTL::ComputePipelineState* get_fft_kernel( MTL::ComputePipelineState* get_quantized_kernel( metal::Device& d, const std::string& kernel_name, - const std::string& template_def); + const std::string& template_def, + const std::string& mode); MTL::ComputePipelineState* get_gather_qmm_kernel( metal::Device& d, @@ -248,6 +249,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel( const array& x, int group_size, int bits, + const std::string& mode, int bm, int bn, int bk, diff --git a/mlx/backend/metal/kernels/fp4_quantized.h b/mlx/backend/metal/kernels/fp4_quantized.h index b5a8918e4..0b22dc1e5 100644 --- a/mlx/backend/metal/kernels/fp4_quantized.h +++ b/mlx/backend/metal/kernels/fp4_quantized.h @@ -270,7 +270,7 @@ struct QuantizedBlockLoader { } }; -template +template METAL_FUNC void mxfp4_qmv_quad_impl( const device uint32_t* w, const device S* scales, @@ -633,8 +633,8 @@ METAL_FUNC void mxfp4_qvm_impl( template < typename T, const int group_size, - const bool aligned_N, typename S, + const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32> @@ -976,7 +976,7 @@ METAL_FUNC void adjust_matrix_offsets( y += tid.z * output_stride; } -template +template [[kernel]] void mxfp4_qmv_quad( const device uint32_t* w, const device S* scales, @@ -1014,7 +1014,7 @@ template tid); } threadgroup float lut[16]; - mxfp4_qmv_quad_impl( + mxfp4_qmv_quad_impl( w, scales, x, @@ -1029,7 +1029,7 @@ template lut); } -template +template [[kernel]] void mxfp4_qmv_fast( const device uint32_t* w, const device S* scales, @@ -1069,7 +1069,7 @@ template w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); } -template +template [[kernel]] void mxfp4_qmv( const device uint32_t* w, const device S* scales, @@ -1109,7 +1109,7 @@ template w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); } -template +template [[kernel]] void mxfp4_qvm( const device uint32_t* w, const device S* scales, @@ -1149,7 +1149,7 @@ template w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); } -template +template [[kernel]] void mxfp4_qvm_split_k( const device uint32_t* w, const device S* scales, @@ -1205,9 +1205,9 @@ template template < typename T, const int group_size, + typename S, const bool aligned_N, const bool batched, - typename S, const int BM = 32, const int BK = 32, const int BN = 32> @@ -1254,15 +1254,15 @@ template < s_strides, tid); } - mxfp4_qmm_t_impl( + mxfp4_qmm_t_impl( w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); } template < typename T, const int group_size, - const bool batched, typename S, + const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> @@ -1468,8 +1468,8 @@ template template < typename T, const int group_size, - const bool aligned_N, typename S, + const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32> @@ -1526,7 +1526,7 @@ template < w_strides, s_strides, tid); - mxfp4_qmm_t_impl( + mxfp4_qmm_t_impl( w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); } diff --git a/mlx/backend/metal/kernels/fp4_quantized.metal b/mlx/backend/metal/kernels/fp4_quantized.metal index c982b4bf4..6b2daf88c 100644 --- a/mlx/backend/metal/kernels/fp4_quantized.metal +++ b/mlx/backend/metal/kernels/fp4_quantized.metal @@ -20,8 +20,8 @@ name, \ type, \ 32, \ - batched, \ - uint8_t) + uint8_t, \ + batched) #define instantiate_quantized_aligned(name, type, aligned) \ instantiate_kernel( \ @@ -29,8 +29,8 @@ name, \ type, \ 32, \ - aligned, \ - uint8_t) + uint8_t, \ + aligned) #define instantiate_quantized_aligned_batched(name, type, aligned, batched) \ instantiate_kernel( \ @@ -38,9 +38,9 @@ name, \ type, \ 32, \ + uint8_t, \ aligned, \ - batched, \ - uint8_t) + batched) #define instantiate_quantized_quad(name, type, D, batched) \ instantiate_kernel( \ @@ -48,9 +48,9 @@ name, \ type, \ 32, \ + uint8_t, \ D, \ - batched, \ - uint8_t) + batched) #define instantiate_quantized_split_k(name, type, split_k) \ instantiate_kernel( \ @@ -58,8 +58,8 @@ name, \ type, \ 32, \ - split_k, \ - uint8_t) + uint8_t, \ + split_k) #define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \ instantiate_kernel( \ diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index a689a793e..109dd8df7 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -283,6 +283,7 @@ MTL::ComputePipelineState* get_fft_kernel( MTL::ComputePipelineState* get_quantized_kernel( metal::Device& d, const std::string& kernel_name, + const std::string&, const std::string&) { return d.get_kernel(kernel_name); } @@ -295,6 +296,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel( const array&, int, int, + const std::string&, int, int, int, diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index f8bc0342e..333f3971f 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -15,6 +15,28 @@ namespace mlx::core { namespace { +template +auto get_quantized_kernel_wrapped( + metal::Device& d, + const std::string& name, + const std::string& func, + const std::string& mode, + const std::string& type, + int group_size, + int bits, + Args... args) { + std::string template_def; + auto fname = mode + "_" + func; + if (mode == "affine") { + template_def = get_template_definition( + name, fname, type, group_size, bits, std::forward(args)...); + } else { + template_def = get_template_definition( + name, fname, type, group_size, "uint8_t", std::forward(args)...); + } + return get_quantized_kernel(d, name, template_def, mode); +} + inline array ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { if (!x.flags().row_contiguous) { @@ -178,10 +200,8 @@ void qmv_quad( "_d_", K, B > 1 ? "_batch_1" : "_batch_0"); - auto template_def = get_template_definition( - kname, mode + "_qmv_quad", type_string, group_size, bits, K, B > 1); - - auto kernel = get_quantized_kernel(d, kname, template_def); + auto kernel = get_quantized_kernel_wrapped( + d, kname, "qmv_quad", mode, type_string, group_size, bits, K, B > 1); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -235,15 +255,16 @@ void qmv( "_b_", bits, B > 1 ? "_batch_1" : "_batch_0"); - auto template_def = get_template_definition( + auto kernel = get_quantized_kernel_wrapped( + d, kname, - mode + (fast ? "_qmv_fast" : "_qmv"), + (fast ? "qmv_fast" : "qmv"), + mode, type_string, group_size, bits, B > 1); - auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -337,11 +358,11 @@ void qvm_split_k( bits, "_spk_", split_k); - auto template_def = get_template_definition( - kname, mode + "_qvm_split_k", type_string, group_size, bits, split_k); // Encode and dispatch kernel - auto kernel = get_quantized_kernel(d, kname, template_def); + auto kernel = get_quantized_kernel_wrapped( + d, kname, "qvm_split_k", mode, type_string, group_size, bits, split_k); + auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -414,10 +435,8 @@ void qvm( "_b_", bits, B > 1 ? "_batch_1" : "_batch_0"); - auto template_def = get_template_definition( - kname, mode + "_qvm", type_string, group_size, bits, B > 1); - - auto kernel = get_quantized_kernel(d, kname, template_def); + auto kernel = get_quantized_kernel_wrapped( + d, kname, "qvm", mode, type_string, group_size, bits, B > 1); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -476,21 +495,22 @@ void qmm( transpose ? (aligned ? "_alN_true" : "_alN_false") : "", batched ? "_batch_1" : "_batch_0"); std::string template_def; + MTL::ComputePipelineState* kernel; if (transpose) { - template_def = get_template_definition( + kernel = get_quantized_kernel_wrapped( + d, kname, - mode + "_qmm_t", + "qmm_t", + mode, type_string, group_size, bits, aligned, batched); } else { - template_def = get_template_definition( - kname, mode + "_qmm_n", type_string, group_size, bits, batched); + kernel = get_quantized_kernel_wrapped( + d, kname, "qmm_n", mode, type_string, group_size, bits, batched); } - - auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -539,7 +559,6 @@ void gather_qmm( std::string kname; kname.reserve(64); bool aligned = N % 32 == 0; - bool batched = B > 1; std::string type_string = get_type_string(x.dtype()); concatenate( kname, @@ -550,16 +569,15 @@ void gather_qmm( "_b_", bits, transpose ? (aligned ? "_alN_true" : "_alN_false") : ""); - std::string template_def; + MTL::ComputePipelineState* kernel; if (transpose) { - template_def = get_template_definition( - kname, mode + "_gather_qmm_t", type_string, group_size, bits, aligned); + kernel = get_quantized_kernel_wrapped( + d, kname, "gather_qmm_t", mode, type_string, group_size, bits, aligned); } else { - template_def = get_template_definition( - kname, mode + "_gather_qmm_n", type_string, group_size, bits); + kernel = get_quantized_kernel_wrapped( + d, kname, "gather_qmm_n", mode, type_string, group_size, bits); } - auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -617,14 +635,16 @@ void gather_qmv( group_size, "_b_", bits); - auto template_def = get_template_definition( + + auto kernel = get_quantized_kernel_wrapped( + d, kname, - mode + (fast ? "_gather_qmv_fast" : "_gather_qmv"), + (fast ? "gather_qmv_fast" : "gather_qmv"), + mode, type_string, group_size, bits); - auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -680,10 +700,8 @@ void gather_qvm( group_size, "_b_", bits); - auto template_def = get_template_definition( - kname, mode + "_gather_qvm", type_string, group_size, bits); - - auto kernel = get_quantized_kernel(d, kname, template_def); + auto kernel = get_quantized_kernel_wrapped( + d, kname, "gather_qvm", mode, type_string, group_size, bits); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -806,6 +824,7 @@ void gather_qmm_rhs( x, group_size, bits, + mode, bm, bn, bk, @@ -1039,15 +1058,27 @@ void fast::Quantize::eval_gpu( compute_encoder.set_output_array(biases, 3); } - std::ostringstream kname; auto type_string = dequantize_ ? get_type_string(out.dtype()) : get_type_string(w_pre.dtype()); - auto kernel_func = dequantize_ ? "affine_dequantize" : "affine_quantize"; - kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_" - << bits_; - auto template_def = get_template_definition( - kname.str(), kernel_func, type_string, group_size_, bits_); - auto kernel = get_quantized_kernel(d, kname.str(), template_def); + std::string kname; + concatenate( + kname, + dequantize_ ? "affine_dequantize" : "affine_quantize", + "_", + type_string, + "_gs_", + group_size_, + "_b_", + bits_); + auto kernel = get_quantized_kernel_wrapped( + d, + kname, + dequantize_ ? "dequantize" : "quantize", + "affine", + type_string, + group_size_, + bits_); + compute_encoder.set_compute_pipeline_state(kernel); // Treat uint32 as uint8 in kernel