fix METAL quantization in JIT (#2553)

Author: Awni Hannun
Date: 2025-08-28 18:26:25 -07:00
Committed by: GitHub
Parent: d363a76aa4
Commit: 827003d568
9 changed files with 154 additions and 89 deletions

View File

@@ -199,7 +199,7 @@ jobs:
 name: Run Python tests with JIT
 command: |
 CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
-uv pip install -e .
+uv pip install -e . -v
 LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
 METAL_DEBUG_ERROR_MODE=0 \
 uv run --no-project python -m xmlrunner discover \
@@ -298,7 +298,7 @@ jobs:
 rm ~/miniconda3/miniconda.sh
 source ~/miniconda3/bin/activate
 conda init --all
-conda create -n env python=<< parameters.python_version >>
+conda create -n env python=<< parameters.python_version >> -y
 conda activate env
 pip install --upgrade cmake
 pip install nanobind==2.4.0

View File

@@ -77,7 +77,10 @@ if(MLX_METAL_JIT)
 make_jit_source(steel/conv/kernels/steel_conv)
 make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h
 kernels/steel/conv/loaders/loader_general.h)
-make_jit_source(quantized)
+make_jit_source(quantized_utils)
+make_jit_source(quantized kernels/quantized_utils.h)
+make_jit_source(fp4_quantized kernels/quantized_utils.h)
 make_jit_source(gemv_masked)
 else()
 target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)

View File

@@ -21,7 +21,9 @@ const char* fft();
 const char* gather_axis();
 const char* hadamard();
 const char* logsumexp();
+const char* quantized_utils();
 const char* quantized();
+const char* fp4_quantized();
 const char* ternary();
 const char* scan();
 const char* scatter_axis();
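
Note: these declarations pair with the make_jit_source() entries added in the CMake hunk above. The sketch below is a guess at the shape of a generated accessor, offered only to make that relationship concrete; the emitted file layout is an assumption, and only the `const char* quantized_utils();` signature comes from this diff.

    // Hypothetical output of make_jit_source(quantized_utils): a translation
    // unit that returns the preprocessed kernel header as a raw string so the
    // JIT can splice it into a larger Metal source at run time.
    namespace mlx::core::metal {

    const char* quantized_utils() {
      return R"(
    // ... preprocessed contents of kernels/quantized_utils.h ...
    )";
    }

    } // namespace mlx::core::metal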

View File

@@ -804,13 +804,19 @@ MTL::ComputePipelineState* get_fft_kernel(
 MTL::ComputePipelineState* get_quantized_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const std::string& template_def) {
+    const std::string& template_def,
+    const std::string& mode) {
   const auto& lib_name = kernel_name;
   auto lib = d.get_library(lib_name, [&]() {
-    std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::gemm() << metal::quantized()
-                  << template_def;
-    return kernel_source.str();
+    std::string kernel_source;
+    concatenate(
+        kernel_source,
+        metal::utils(),
+        metal::gemm(),
+        metal::quantized_utils(),
+        (mode == "affine") ? metal::quantized() : metal::fp4_quantized(),
+        template_def);
+    return kernel_source;
   });
   return d.get_kernel(kernel_name, lib);
 }
@@ -823,6 +829,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
     const array& x,
     int group_size,
     int bits,
+    const std::string& mode,
     int bm,
     int bn,
     int bk,
@@ -833,22 +840,40 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
   auto lib = d.get_library(lib_name, [&]() {
     std::string kernel_source;
     concatenate(
-        kernel_source,
-        metal::utils(),
-        metal::gemm(),
-        metal::quantized(),
-        get_template_definition(
-            lib_name,
-            "gather_qmm_rhs",
-            get_type_string(x.dtype()),
-            group_size,
-            bits,
-            bm,
-            bn,
-            bk,
-            wm,
-            wn,
-            transpose));
+        kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
+    if (mode == "affine") {
+      concatenate(
+          kernel_source,
+          metal::quantized(),
+          get_template_definition(
+              lib_name,
+              mode + "_gather_qmm_rhs",
+              get_type_string(x.dtype()),
+              group_size,
+              bits,
+              bm,
+              bn,
+              bk,
+              wm,
+              wn,
+              transpose));
+    } else {
+      concatenate(
+          kernel_source,
+          metal::fp4_quantized(),
+          get_template_definition(
+              lib_name,
+              mode + "_gather_qmm_rhs",
+              get_type_string(x.dtype()),
+              group_size,
+              "uint8_t",
+              bm,
+              bn,
+              bk,
+              wm,
+              wn,
+              transpose));
+    }
     return kernel_source;
   });
   return d.get_kernel(kernel_name, lib, hash_name, func_consts);
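
Note: in both functions the net effect is the same: the shared utility sources are always emitted, and only the mode-specific quantization header plus the explicit instantiation differ. A standalone sketch of that assembly, with hypothetical stand-in functions for the metal::* source accessors (none of the strings below are real MLX kernel source):

    #include <string>

    // Stand-ins for the real source accessors (contents are placeholders).
    std::string utils_src() { return "// utils\n"; }
    std::string gemm_src() { return "// steel gemm\n"; }
    std::string quantized_utils_src() { return "// shared quantized helpers\n"; }
    std::string affine_src() { return "// affine quantized kernels\n"; }
    std::string fp4_src() { return "// mxfp4 quantized kernels\n"; }

    // Mirrors the selection above: shared pieces first, then the header and
    // explicit instantiation that match the requested quantization mode.
    std::string build_quantized_source(
        const std::string& mode,
        const std::string& template_def) {
      std::string src = utils_src() + gemm_src() + quantized_utils_src();
      src += (mode == "affine") ? affine_src() : fp4_src();
      src += template_def;
      return src;
    }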

View File

@@ -238,7 +238,8 @@ MTL::ComputePipelineState* get_fft_kernel(
 MTL::ComputePipelineState* get_quantized_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const std::string& template_def);
+    const std::string& template_def,
+    const std::string& mode);

 MTL::ComputePipelineState* get_gather_qmm_kernel(
     metal::Device& d,
@@ -248,6 +249,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
     const array& x,
     int group_size,
     int bits,
+    const std::string& mode,
     int bm,
     int bn,
     int bk,

View File

@@ -270,7 +270,7 @@ struct QuantizedBlockLoader {
   }
 };

-template <typename T, int group_size, int D, typename S>
+template <typename T, int group_size, typename S, int D>
 METAL_FUNC void mxfp4_qmv_quad_impl(
     const device uint32_t* w,
     const device S* scales,
@@ -633,8 +633,8 @@ METAL_FUNC void mxfp4_qvm_impl(
 template <
     typename T,
     const int group_size,
-    const bool aligned_N,
     typename S,
+    const bool aligned_N,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -976,7 +976,7 @@ METAL_FUNC void adjust_matrix_offsets(
   y += tid.z * output_stride;
 }

-template <typename T, int group_size, int D, bool batched, typename S>
+template <typename T, int group_size, typename S, int D, bool batched>
 [[kernel]] void mxfp4_qmv_quad(
     const device uint32_t* w,
     const device S* scales,
@@ -1014,7 +1014,7 @@ template <typename T, int group_size, int D, bool batched, typename S>
         tid);
   }
   threadgroup float lut[16];
-  mxfp4_qmv_quad_impl<T, group_size, D>(
+  mxfp4_qmv_quad_impl<T, group_size, S, D>(
       w,
       scales,
       x,
@@ -1029,7 +1029,7 @@ template <typename T, int group_size, int D, bool batched, typename S>
       lut);
 }

-template <typename T, int group_size, bool batched, typename S>
+template <typename T, int group_size, typename S, bool batched>
 [[kernel]] void mxfp4_qmv_fast(
     const device uint32_t* w,
     const device S* scales,
@@ -1069,7 +1069,7 @@ template <typename T, int group_size, bool batched, typename S>
       w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }

-template <typename T, const int group_size, bool batched, typename S>
+template <typename T, const int group_size, typename S, bool batched>
 [[kernel]] void mxfp4_qmv(
     const device uint32_t* w,
     const device S* scales,
@@ -1109,7 +1109,7 @@ template <typename T, const int group_size, bool batched, typename S>
       w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }

-template <typename T, const int group_size, bool batched, typename S>
+template <typename T, const int group_size, typename S, bool batched>
 [[kernel]] void mxfp4_qvm(
     const device uint32_t* w,
     const device S* scales,
@@ -1149,7 +1149,7 @@ template <typename T, const int group_size, bool batched, typename S>
       w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }

-template <typename T, const int group_size, int split_k = 32, typename S>
+template <typename T, const int group_size, typename S, int split_k = 32>
 [[kernel]] void mxfp4_qvm_split_k(
     const device uint32_t* w,
     const device S* scales,
@@ -1205,9 +1205,9 @@ template <typename T, const int group_size, int split_k = 32, typename S>
 template <
     typename T,
     const int group_size,
+    typename S,
     const bool aligned_N,
     const bool batched,
-    typename S,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -1254,15 +1254,15 @@ template <
         s_strides,
         tid);
   }
-  mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
+  mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
       w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }

 template <
     typename T,
     const int group_size,
-    const bool batched,
     typename S,
+    const bool batched,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -1468,8 +1468,8 @@ template <typename T, int group_size, typename S>
 template <
     typename T,
     const int group_size,
-    const bool aligned_N,
     typename S,
+    const bool aligned_N,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -1526,7 +1526,7 @@ template <
       w_strides,
       s_strides,
       tid);
-  mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
+  mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
       w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
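
Note: every hunk in this header moves `typename S` (the scale type) into the third template slot, ahead of the remaining flags and the defaulted tile sizes, so the JIT wrapper added in quantized.cpp can emit template arguments positionally as (type, group_size, S-or-bits, extras...). A toy illustration of that calling convention (not MLX code; the stub and its arguments are made up):

    #include <cstdint>

    // Toy kernel stub with the new parameter order: the scale type S is third,
    // and defaulted parameters stay at the end.
    template <typename T, int group_size, typename S, int split_k = 32>
    void qvm_split_k_stub(const S* scales, const T* x) {
      (void)scales;
      (void)x;
    }

    int main() {
      // Positional arguments in the order the JIT emits them:
      // <type, group_size, scale type>, leaving split_k at its default.
      qvm_split_k_stub<float, 32, std::uint8_t>(nullptr, nullptr);
    }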

View File

@@ -20,8 +20,8 @@
       name, \
       type, \
       32, \
-      batched, \
-      uint8_t)
+      uint8_t, \
+      batched)

 #define instantiate_quantized_aligned(name, type, aligned) \
   instantiate_kernel( \
@@ -29,8 +29,8 @@
       name, \
       type, \
       32, \
-      aligned, \
-      uint8_t)
+      uint8_t, \
+      aligned)

 #define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
   instantiate_kernel( \
@@ -38,9 +38,9 @@
       name, \
       type, \
       32, \
+      uint8_t, \
       aligned, \
-      batched, \
-      uint8_t)
+      batched)

 #define instantiate_quantized_quad(name, type, D, batched) \
   instantiate_kernel( \
@@ -48,9 +48,9 @@
       name, \
       type, \
       32, \
+      uint8_t, \
       D, \
-      batched, \
-      uint8_t)
+      batched)

 #define instantiate_quantized_split_k(name, type, split_k) \
   instantiate_kernel( \
@@ -58,8 +58,8 @@
       name, \
       type, \
       32, \
-      split_k, \
-      uint8_t)
+      uint8_t, \
+      split_k)

 #define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
   instantiate_kernel( \

View File

@@ -283,6 +283,7 @@ MTL::ComputePipelineState* get_fft_kernel(
 MTL::ComputePipelineState* get_quantized_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string&,
     const std::string&) {
   return d.get_kernel(kernel_name);
 }
@@ -295,6 +296,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
     const array&,
     int,
     int,
+    const std::string&,
     int,
     int,
     int,

View File

@@ -15,6 +15,28 @@ namespace mlx::core {
 namespace {

+template <typename... Args>
+auto get_quantized_kernel_wrapped(
+    metal::Device& d,
+    const std::string& name,
+    const std::string& func,
+    const std::string& mode,
+    const std::string& type,
+    int group_size,
+    int bits,
+    Args... args) {
+  std::string template_def;
+  auto fname = mode + "_" + func;
+  if (mode == "affine") {
+    template_def = get_template_definition(
+        name, fname, type, group_size, bits, std::forward<Args>(args)...);
+  } else {
+    template_def = get_template_definition(
+        name, fname, type, group_size, "uint8_t", std::forward<Args>(args)...);
+  }
+  return get_quantized_kernel(d, name, template_def, mode);
+}
+
 inline array
 ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
   if (!x.flags().row_contiguous) {
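
Note: the wrapper's two branches differ only in the third template argument: affine kernels are instantiated with the integer bit width, while the mxfp4 kernels take a scale type there and always receive uint8_t. A self-contained sketch of the argument lists this produces (join_args is a hypothetical helper, not MLX's get_template_definition, and the concrete values are illustrative):

    #include <iostream>
    #include <sstream>
    #include <string>

    // Join heterogeneous template arguments into a comma-separated list.
    template <typename... Args>
    std::string join_args(Args&&... args) {
      std::ostringstream oss;
      bool first = true;
      ((oss << (first ? "" : ", ") << args, first = false), ...);
      return oss.str();
    }

    int main() {
      // affine mode: bit width in the third slot.
      std::cout << "affine_qmv_fast<" << join_args("float16_t", 64, 4, "false")
                << ">\n";
      // fp4 mode: the scale type goes in that slot instead.
      std::cout << "mxfp4_qmv_fast<" << join_args("float16_t", 32, "uint8_t", "false")
                << ">\n";
    }
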
@@ -178,10 +200,8 @@ void qmv_quad(
"_d_", "_d_",
K, K,
B > 1 ? "_batch_1" : "_batch_0"); B > 1 ? "_batch_1" : "_batch_0");
auto template_def = get_template_definition( auto kernel = get_quantized_kernel_wrapped(
kname, mode + "_qmv_quad", type_string, group_size, bits, K, B > 1); d, kname, "qmv_quad", mode, type_string, group_size, bits, K, B > 1);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
@@ -235,15 +255,16 @@ void qmv(
"_b_", "_b_",
bits, bits,
B > 1 ? "_batch_1" : "_batch_0"); B > 1 ? "_batch_1" : "_batch_0");
auto template_def = get_template_definition( auto kernel = get_quantized_kernel_wrapped(
d,
kname, kname,
mode + (fast ? "_qmv_fast" : "_qmv"), (fast ? "qmv_fast" : "qmv"),
mode,
type_string, type_string,
group_size, group_size,
bits, bits,
B > 1); B > 1);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
@@ -337,11 +358,11 @@ void qvm_split_k(
       bits,
       "_spk_",
       split_k);
-  auto template_def = get_template_definition(
-      kname, mode + "_qvm_split_k", type_string, group_size, bits, split_k);

   // Encode and dispatch kernel
-  auto kernel = get_quantized_kernel(d, kname, template_def);
+  auto kernel = get_quantized_kernel_wrapped(
+      d, kname, "qvm_split_k", mode, type_string, group_size, bits, split_k);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
@@ -414,10 +435,8 @@ void qvm(
"_b_", "_b_",
bits, bits,
B > 1 ? "_batch_1" : "_batch_0"); B > 1 ? "_batch_1" : "_batch_0");
auto template_def = get_template_definition( auto kernel = get_quantized_kernel_wrapped(
kname, mode + "_qvm", type_string, group_size, bits, B > 1); d, kname, "qvm", mode, type_string, group_size, bits, B > 1);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
@@ -476,21 +495,22 @@ void qmm(
       transpose ? (aligned ? "_alN_true" : "_alN_false") : "",
       batched ? "_batch_1" : "_batch_0");

   std::string template_def;
+  MTL::ComputePipelineState* kernel;
   if (transpose) {
-    template_def = get_template_definition(
+    kernel = get_quantized_kernel_wrapped(
+        d,
         kname,
-        mode + "_qmm_t",
+        "qmm_t",
+        mode,
         type_string,
         group_size,
         bits,
         aligned,
         batched);
   } else {
-    template_def = get_template_definition(
-        kname, mode + "_qmm_n", type_string, group_size, bits, batched);
+    kernel = get_quantized_kernel_wrapped(
+        d, kname, "qmm_n", mode, type_string, group_size, bits, batched);
   }
-  auto kernel = get_quantized_kernel(d, kname, template_def);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
@@ -539,7 +559,6 @@ void gather_qmm(
   std::string kname;
   kname.reserve(64);
   bool aligned = N % 32 == 0;
-  bool batched = B > 1;
   std::string type_string = get_type_string(x.dtype());
   concatenate(
       kname,
@@ -550,16 +569,15 @@ void gather_qmm(
"_b_", "_b_",
bits, bits,
transpose ? (aligned ? "_alN_true" : "_alN_false") : ""); transpose ? (aligned ? "_alN_true" : "_alN_false") : "");
std::string template_def; MTL::ComputePipelineState* kernel;
if (transpose) { if (transpose) {
template_def = get_template_definition( kernel = get_quantized_kernel_wrapped(
kname, mode + "_gather_qmm_t", type_string, group_size, bits, aligned); d, kname, "gather_qmm_t", mode, type_string, group_size, bits, aligned);
} else { } else {
template_def = get_template_definition( kernel = get_quantized_kernel_wrapped(
kname, mode + "_gather_qmm_n", type_string, group_size, bits); d, kname, "gather_qmm_n", mode, type_string, group_size, bits);
} }
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_compute_pipeline_state(kernel);
@@ -617,14 +635,16 @@ void gather_qmv(
       group_size,
       "_b_",
       bits);
-  auto template_def = get_template_definition(
+  auto kernel = get_quantized_kernel_wrapped(
+      d,
       kname,
-      mode + (fast ? "_gather_qmv_fast" : "_gather_qmv"),
+      (fast ? "gather_qmv_fast" : "gather_qmv"),
+      mode,
       type_string,
       group_size,
       bits);
-  auto kernel = get_quantized_kernel(d, kname, template_def);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
@@ -680,10 +700,8 @@ void gather_qvm(
       group_size,
       "_b_",
       bits);
-  auto template_def = get_template_definition(
-      kname, mode + "_gather_qvm", type_string, group_size, bits);
-  auto kernel = get_quantized_kernel(d, kname, template_def);
+  auto kernel = get_quantized_kernel_wrapped(
+      d, kname, "gather_qvm", mode, type_string, group_size, bits);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
@@ -806,6 +824,7 @@ void gather_qmm_rhs(
       x,
       group_size,
       bits,
+      mode,
       bm,
       bn,
       bk,
@@ -1039,15 +1058,27 @@ void fast::Quantize::eval_gpu(
     compute_encoder.set_output_array(biases, 3);
   }

-  std::ostringstream kname;
   auto type_string = dequantize_ ? get_type_string(out.dtype())
                                  : get_type_string(w_pre.dtype());
-  auto kernel_func = dequantize_ ? "affine_dequantize" : "affine_quantize";
-  kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_"
-        << bits_;
-  auto template_def = get_template_definition(
-      kname.str(), kernel_func, type_string, group_size_, bits_);
-  auto kernel = get_quantized_kernel(d, kname.str(), template_def);
+  std::string kname;
+  concatenate(
+      kname,
+      dequantize_ ? "affine_dequantize" : "affine_quantize",
+      "_",
+      type_string,
+      "_gs_",
+      group_size_,
+      "_b_",
+      bits_);
+  auto kernel = get_quantized_kernel_wrapped(
+      d,
+      kname,
+      dequantize_ ? "dequantize" : "quantize",
+      "affine",
+      type_string,
+      group_size_,
+      bits_);
   compute_encoder.set_compute_pipeline_state(kernel);

   // Treat uint32 as uint8 in kernel
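
Note: as a quick sanity check, the new concatenation yields kernel names of the form affine_(de)quantize_<type>_gs_<group_size>_b_<bits>. A self-contained sketch using plain std::string appends in place of MLX's concatenate helper (the dtype, group size, and bit width below are illustrative):

    #include <iostream>
    #include <string>

    int main() {
      bool dequantize = false;
      std::string type_string = "float16_t";
      int group_size = 64, bits = 4;

      // Same pieces as the concatenate() call above, in the same order.
      std::string kname;
      kname += dequantize ? "affine_dequantize" : "affine_quantize";
      kname += "_" + type_string + "_gs_" + std::to_string(group_size) + "_b_" +
          std::to_string(bits);
      std::cout << kname << "\n"; // affine_quantize_float16_t_gs_64_b_4
    }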