fix METAL quantization in JIT (#2553)

Awni Hannun
2025-08-28 18:26:25 -07:00
committed by GitHub
parent d363a76aa4
commit 827003d568
9 changed files with 154 additions and 89 deletions

View File

@@ -199,7 +199,7 @@ jobs:
           name: Run Python tests with JIT
           command: |
             CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
-            uv pip install -e .
+            uv pip install -e . -v
             LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
             METAL_DEBUG_ERROR_MODE=0 \
             uv run --no-project python -m xmlrunner discover \
@@ -298,7 +298,7 @@ jobs:
rm ~/miniconda3/miniconda.sh
source ~/miniconda3/bin/activate
conda init --all
conda create -n env python=<< parameters.python_version >>
conda create -n env python=<< parameters.python_version >> -y
conda activate env
pip install --upgrade cmake
pip install nanobind==2.4.0

View File

@@ -77,7 +77,10 @@ if(MLX_METAL_JIT)
   make_jit_source(steel/conv/kernels/steel_conv)
   make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h
                   kernels/steel/conv/loaders/loader_general.h)
-  make_jit_source(quantized)
+  make_jit_source(quantized_utils)
+  make_jit_source(quantized kernels/quantized_utils.h)
+  make_jit_source(fp4_quantized kernels/quantized_utils.h)
   make_jit_source(gemv_masked)
 else()
   target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)

View File

@@ -21,7 +21,9 @@ const char* fft();
 const char* gather_axis();
 const char* hadamard();
 const char* logsumexp();
+const char* quantized_utils();
 const char* quantized();
+const char* fp4_quantized();
 const char* ternary();
 const char* scan();
 const char* scatter_axis();

View File

@@ -804,13 +804,19 @@ MTL::ComputePipelineState* get_fft_kernel(
 MTL::ComputePipelineState* get_quantized_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const std::string& template_def) {
+    const std::string& template_def,
+    const std::string& mode) {
   const auto& lib_name = kernel_name;
   auto lib = d.get_library(lib_name, [&]() {
-    std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::gemm() << metal::quantized()
-                  << template_def;
-    return kernel_source.str();
+    std::string kernel_source;
+    concatenate(
+        kernel_source,
+        metal::utils(),
+        metal::gemm(),
+        metal::quantized_utils(),
+        (mode == "affine") ? metal::quantized() : metal::fp4_quantized(),
+        template_def);
+    return kernel_source;
   });
   return d.get_kernel(kernel_name, lib);
 }
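This hunk is the heart of the fix: the JIT previously spliced only metal::quantized() (the affine kernels) into every quantized library, so template definitions for the mxfp4 kernels had no source to instantiate against under MLX_METAL_JIT. The new mode argument selects the matching source, with quantized_utils shared by both. A minimal standalone sketch of the selection logic (the source-getter names below are stand-ins, not MLX API):

    #include <string>

    // Stubs standing in for the metal::* source getters (assumed names).
    std::string utils_source();
    std::string gemm_source();
    std::string quantized_utils_source();
    std::string affine_source();
    std::string fp4_source();

    // Mirrors the dispatch above: both modes share quantized_utils; only
    // the kernel body source differs by mode.
    std::string build_quantized_source(
        const std::string& mode, const std::string& template_def) {
      std::string src;
      src += utils_source();
      src += gemm_source();
      src += quantized_utils_source();
      src += (mode == "affine") ? affine_source() : fp4_source();
      src += template_def; // explicit instantiation for this kernel
      return src;
    }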
@@ -823,6 +829,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
     const array& x,
     int group_size,
     int bits,
+    const std::string& mode,
     int bm,
     int bn,
     int bk,
@@ -833,22 +840,40 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
   auto lib = d.get_library(lib_name, [&]() {
     std::string kernel_source;
     concatenate(
-        kernel_source,
-        metal::utils(),
-        metal::gemm(),
-        metal::quantized(),
-        get_template_definition(
-            lib_name,
-            "gather_qmm_rhs",
-            get_type_string(x.dtype()),
-            group_size,
-            bits,
-            bm,
-            bn,
-            bk,
-            wm,
-            wn,
-            transpose));
+        kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
+    if (mode == "affine") {
+      concatenate(
+          kernel_source,
+          metal::quantized(),
+          get_template_definition(
+              lib_name,
+              mode + "_gather_qmm_rhs",
+              get_type_string(x.dtype()),
+              group_size,
+              bits,
+              bm,
+              bn,
+              bk,
+              wm,
+              wn,
+              transpose));
+    } else {
+      concatenate(
+          kernel_source,
+          metal::fp4_quantized(),
+          get_template_definition(
+              lib_name,
+              mode + "_gather_qmm_rhs",
+              get_type_string(x.dtype()),
+              group_size,
+              "uint8_t",
+              bm,
+              bn,
+              bk,
+              wm,
+              wn,
+              transpose));
+    }
     return kernel_source;
   });
   return d.get_kernel(kernel_name, lib, hash_name, func_consts);
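Note the asymmetry between the two branches: the affine branch forwards the runtime bits value into the template definition, while the mxfp4 branch hard-codes the literal "uint8_t" in that same slot, because the mxfp4 kernels are parameterized on a scale type rather than a bit width. A reduced sketch of just that substitution (the helper name and truncated argument list are illustrative, not the real get_template_definition):

    #include <sstream>
    #include <string>

    // Reduced sketch of the bits-vs-scale-type substitution; the real
    // template definition also carries tile sizes and the transpose flag.
    std::string emit_template_args(
        const std::string& mode,
        const std::string& type,
        int group_size,
        int bits) {
      std::ostringstream os;
      os << type << ", " << group_size << ", ";
      if (mode == "affine") {
        os << bits; // affine kernels: third parameter is the bit width
      } else {
        os << "uint8_t"; // mxfp4 kernels: third parameter is the scale type
      }
      return os.str();
    }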

View File

@@ -238,7 +238,8 @@ MTL::ComputePipelineState* get_fft_kernel(
 MTL::ComputePipelineState* get_quantized_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const std::string& template_def);
+    const std::string& template_def,
+    const std::string& mode);
 
 MTL::ComputePipelineState* get_gather_qmm_kernel(
     metal::Device& d,
@@ -248,6 +249,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
     const array& x,
     int group_size,
     int bits,
+    const std::string& mode,
     int bm,
     int bn,
     int bk,

View File

@@ -270,7 +270,7 @@ struct QuantizedBlockLoader {
   }
 };
 
-template <typename T, int group_size, int D, typename S>
+template <typename T, int group_size, typename S, int D>
 METAL_FUNC void mxfp4_qmv_quad_impl(
     const device uint32_t* w,
     const device S* scales,
@@ -633,8 +633,8 @@ METAL_FUNC void mxfp4_qvm_impl(
 template <
     typename T,
     const int group_size,
-    const bool aligned_N,
     typename S,
+    const bool aligned_N,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -976,7 +976,7 @@ METAL_FUNC void adjust_matrix_offsets(
   y += tid.z * output_stride;
 }
 
-template <typename T, int group_size, int D, bool batched, typename S>
+template <typename T, int group_size, typename S, int D, bool batched>
 [[kernel]] void mxfp4_qmv_quad(
     const device uint32_t* w,
     const device S* scales,
@@ -1014,7 +1014,7 @@ template <typename T, int group_size, int D, bool batched, typename S>
         tid);
   }
   threadgroup float lut[16];
-  mxfp4_qmv_quad_impl<T, group_size, D>(
+  mxfp4_qmv_quad_impl<T, group_size, S, D>(
       w,
       scales,
       x,
@@ -1029,7 +1029,7 @@ template <typename T, int group_size, int D, bool batched, typename S>
       lut);
 }
 
-template <typename T, int group_size, bool batched, typename S>
+template <typename T, int group_size, typename S, bool batched>
 [[kernel]] void mxfp4_qmv_fast(
     const device uint32_t* w,
     const device S* scales,
@@ -1069,7 +1069,7 @@ template <typename T, int group_size, bool batched, typename S>
   w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, const int group_size, bool batched, typename S>
+template <typename T, const int group_size, typename S, bool batched>
 [[kernel]] void mxfp4_qmv(
     const device uint32_t* w,
     const device S* scales,
@@ -1109,7 +1109,7 @@ template <typename T, const int group_size, bool batched, typename S>
   w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, const int group_size, bool batched, typename S>
+template <typename T, const int group_size, typename S, bool batched>
 [[kernel]] void mxfp4_qvm(
     const device uint32_t* w,
     const device S* scales,
@@ -1149,7 +1149,7 @@ template <typename T, const int group_size, bool batched, typename S>
   w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, const int group_size, int split_k = 32, typename S>
+template <typename T, const int group_size, typename S, int split_k = 32>
 [[kernel]] void mxfp4_qvm_split_k(
     const device uint32_t* w,
     const device S* scales,
@@ -1205,9 +1205,9 @@ template <typename T, const int group_size, int split_k = 32, typename S>
 template <
     typename T,
     const int group_size,
+    typename S,
     const bool aligned_N,
     const bool batched,
-    typename S,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -1254,15 +1254,15 @@ template <
         s_strides,
         tid);
   }
-  mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
+  mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
       w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
 template <
     typename T,
     const int group_size,
-    const bool batched,
     typename S,
+    const bool batched,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -1468,8 +1468,8 @@ template <typename T, int group_size, typename S>
 template <
     typename T,
     const int group_size,
-    const bool aligned_N,
     typename S,
+    const bool aligned_N,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -1526,7 +1526,7 @@ template <
         w_strides,
         s_strides,
         tid);
-  mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
+  mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
       w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
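Every edit in this file is the same mechanical move: typename S migrates ahead of the trailing value parameters. That serves two purposes. First, the scale type now occupies the same template slot where the affine kernels take bits, so one JIT code path can fill that slot for either mode (see get_quantized_kernel_wrapped later in this diff). Second, the parameters with defaults (BM, BK, BN, split_k) now come last, so explicit instantiations can omit them. A minimal illustration of the second point (toy function templates, not MLX code):

    // Toy templates showing why defaulted parameters should trail:
    template <typename T, int group_size, int split_k = 32, typename S>
    void old_form(); // naming S explicitly forces spelling out split_k:
                     // old_form<float, 32, 32, unsigned char>();

    template <typename T, int group_size, typename S, int split_k = 32>
    void new_form(); // the default is now usable:
                     // new_form<float, 32, unsigned char>();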

View File

@@ -20,8 +20,8 @@
       name, \
       type, \
       32, \
-      batched, \
-      uint8_t)
+      uint8_t, \
+      batched)
 
 #define instantiate_quantized_aligned(name, type, aligned) \
   instantiate_kernel( \
@@ -29,8 +29,8 @@
       name, \
       type, \
       32, \
-      aligned, \
-      uint8_t)
+      uint8_t, \
+      aligned)
 
 #define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
   instantiate_kernel( \
@@ -38,9 +38,9 @@
       name, \
       type, \
       32, \
+      uint8_t, \
       aligned, \
-      batched, \
-      uint8_t)
+      batched)
 
 #define instantiate_quantized_quad(name, type, D, batched) \
   instantiate_kernel( \
@@ -48,9 +48,9 @@
       name, \
       type, \
       32, \
+      uint8_t, \
       D, \
-      batched, \
-      uint8_t)
+      batched)
 
 #define instantiate_quantized_split_k(name, type, split_k) \
   instantiate_kernel( \
@@ -58,8 +58,8 @@
       name, \
       type, \
       32, \
-      split_k, \
-      uint8_t)
+      uint8_t, \
+      split_k)
 
 #define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
   instantiate_kernel( \

View File

@@ -283,6 +283,7 @@ MTL::ComputePipelineState* get_fft_kernel(
 MTL::ComputePipelineState* get_quantized_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string&,
     const std::string&) {
   return d.get_kernel(kernel_name);
 }
@@ -295,6 +296,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
     const array&,
     int,
     int,
+    const std::string&,
     int,
     int,
     int,

View File

@@ -15,6 +15,28 @@ namespace mlx::core {
 
 namespace {
 
+template <typename... Args>
+auto get_quantized_kernel_wrapped(
+    metal::Device& d,
+    const std::string& name,
+    const std::string& func,
+    const std::string& mode,
+    const std::string& type,
+    int group_size,
+    int bits,
+    Args... args) {
+  std::string template_def;
+  auto fname = mode + "_" + func;
+  if (mode == "affine") {
+    template_def = get_template_definition(
+        name, fname, type, group_size, bits, std::forward<Args>(args)...);
+  } else {
+    template_def = get_template_definition(
+        name, fname, type, group_size, "uint8_t", std::forward<Args>(args)...);
+  }
+  return get_quantized_kernel(d, name, template_def, mode);
+}
+
 inline array
 ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
   if (!x.flags().row_contiguous) {
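Every quantized-kernel launch below funnels through this wrapper: it builds the decorated function name (mode + "_" + func, e.g. affine_qmv or mxfp4_qmv) and fills the third template slot with either bits (affine) or "uint8_t" (mxfp4); callers still pass bits in mxfp4 mode, where it is simply ignored. A self-contained simulation of those two decisions, with illustrative values (the mode string "mxfp4" is inferred from the kernel names in this diff):

    #include <iostream>
    #include <string>

    // Simulates the wrapper's name and template-slot selection; the real
    // function returns a Metal compute pipeline state instead of printing.
    void show_dispatch(const std::string& mode, int group_size, int bits) {
      std::string fname = mode + "_qmv"; // mode + "_" + func
      std::string third =
          (mode == "affine") ? std::to_string(bits) : "uint8_t";
      std::cout << fname << "<float16_t, " << group_size << ", " << third
                << ", true>\n";
    }

    int main() {
      show_dispatch("affine", 64, 4); // affine_qmv<float16_t, 64, 4, true>
      show_dispatch("mxfp4", 32, 4);  // mxfp4_qmv<float16_t, 32, uint8_t, true>
    }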
@@ -178,10 +200,8 @@ void qmv_quad(
       "_d_",
       K,
       B > 1 ? "_batch_1" : "_batch_0");
-  auto template_def = get_template_definition(
-      kname, mode + "_qmv_quad", type_string, group_size, bits, K, B > 1);
-  auto kernel = get_quantized_kernel(d, kname, template_def);
+  auto kernel = get_quantized_kernel_wrapped(
+      d, kname, "qmv_quad", mode, type_string, group_size, bits, K, B > 1);
 
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
@@ -235,15 +255,16 @@ void qmv(
       "_b_",
       bits,
       B > 1 ? "_batch_1" : "_batch_0");
-  auto template_def = get_template_definition(
+  auto kernel = get_quantized_kernel_wrapped(
+      d,
       kname,
-      mode + (fast ? "_qmv_fast" : "_qmv"),
+      (fast ? "qmv_fast" : "qmv"),
+      mode,
       type_string,
       group_size,
       bits,
       B > 1);
-  auto kernel = get_quantized_kernel(d, kname, template_def);
 
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
@@ -337,11 +358,11 @@ void qvm_split_k(
       bits,
       "_spk_",
       split_k);
-  auto template_def = get_template_definition(
-      kname, mode + "_qvm_split_k", type_string, group_size, bits, split_k);
 
   // Encode and dispatch kernel
-  auto kernel = get_quantized_kernel(d, kname, template_def);
+  auto kernel = get_quantized_kernel_wrapped(
+      d, kname, "qvm_split_k", mode, type_string, group_size, bits, split_k);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
@@ -414,10 +435,8 @@ void qvm(
       "_b_",
       bits,
       B > 1 ? "_batch_1" : "_batch_0");
-  auto template_def = get_template_definition(
-      kname, mode + "_qvm", type_string, group_size, bits, B > 1);
-  auto kernel = get_quantized_kernel(d, kname, template_def);
+  auto kernel = get_quantized_kernel_wrapped(
+      d, kname, "qvm", mode, type_string, group_size, bits, B > 1);
 
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
@@ -476,21 +495,22 @@ void qmm(
       transpose ? (aligned ? "_alN_true" : "_alN_false") : "",
       batched ? "_batch_1" : "_batch_0");
-  std::string template_def;
+  MTL::ComputePipelineState* kernel;
   if (transpose) {
-    template_def = get_template_definition(
+    kernel = get_quantized_kernel_wrapped(
+        d,
         kname,
-        mode + "_qmm_t",
+        "qmm_t",
+        mode,
         type_string,
         group_size,
         bits,
         aligned,
         batched);
   } else {
-    template_def = get_template_definition(
-        kname, mode + "_qmm_n", type_string, group_size, bits, batched);
+    kernel = get_quantized_kernel_wrapped(
+        d, kname, "qmm_n", mode, type_string, group_size, bits, batched);
   }
-  auto kernel = get_quantized_kernel(d, kname, template_def);
 
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
@@ -539,7 +559,6 @@ void gather_qmm(
   std::string kname;
   kname.reserve(64);
   bool aligned = N % 32 == 0;
-  bool batched = B > 1;
   std::string type_string = get_type_string(x.dtype());
   concatenate(
       kname,
@@ -550,16 +569,15 @@
       "_b_",
       bits,
       transpose ? (aligned ? "_alN_true" : "_alN_false") : "");
-  std::string template_def;
+  MTL::ComputePipelineState* kernel;
   if (transpose) {
-    template_def = get_template_definition(
-        kname, mode + "_gather_qmm_t", type_string, group_size, bits, aligned);
+    kernel = get_quantized_kernel_wrapped(
+        d, kname, "gather_qmm_t", mode, type_string, group_size, bits, aligned);
   } else {
-    template_def = get_template_definition(
-        kname, mode + "_gather_qmm_n", type_string, group_size, bits);
+    kernel = get_quantized_kernel_wrapped(
+        d, kname, "gather_qmm_n", mode, type_string, group_size, bits);
   }
-  auto kernel = get_quantized_kernel(d, kname, template_def);
 
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
@@ -617,14 +635,16 @@ void gather_qmv(
       group_size,
       "_b_",
       bits);
-  auto template_def = get_template_definition(
+  auto kernel = get_quantized_kernel_wrapped(
+      d,
       kname,
-      mode + (fast ? "_gather_qmv_fast" : "_gather_qmv"),
+      (fast ? "gather_qmv_fast" : "gather_qmv"),
+      mode,
       type_string,
       group_size,
       bits);
-  auto kernel = get_quantized_kernel(d, kname, template_def);
 
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
@@ -680,10 +700,8 @@ void gather_qvm(
       group_size,
       "_b_",
       bits);
-  auto template_def = get_template_definition(
-      kname, mode + "_gather_qvm", type_string, group_size, bits);
-  auto kernel = get_quantized_kernel(d, kname, template_def);
+  auto kernel = get_quantized_kernel_wrapped(
+      d, kname, "gather_qvm", mode, type_string, group_size, bits);
 
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
@@ -806,6 +824,7 @@ void gather_qmm_rhs(
       x,
       group_size,
       bits,
+      mode,
       bm,
       bn,
       bk,
@@ -1039,15 +1058,27 @@ void fast::Quantize::eval_gpu(
     compute_encoder.set_output_array(biases, 3);
   }
 
-  std::ostringstream kname;
   auto type_string = dequantize_ ? get_type_string(out.dtype())
                                  : get_type_string(w_pre.dtype());
-  auto kernel_func = dequantize_ ? "affine_dequantize" : "affine_quantize";
-  kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_"
-        << bits_;
-  auto template_def = get_template_definition(
-      kname.str(), kernel_func, type_string, group_size_, bits_);
-  auto kernel = get_quantized_kernel(d, kname.str(), template_def);
+  std::string kname;
+  concatenate(
+      kname,
+      dequantize_ ? "affine_dequantize" : "affine_quantize",
+      "_",
+      type_string,
+      "_gs_",
+      group_size_,
+      "_b_",
+      bits_);
+  auto kernel = get_quantized_kernel_wrapped(
+      d,
+      kname,
+      dequantize_ ? "dequantize" : "quantize",
+      "affine",
+      type_string,
+      group_size_,
+      bits_);
   compute_encoder.set_compute_pipeline_state(kernel);
 
   // Treat uint32 as uint8 in kernel
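Quantize::eval_gpu always runs in affine mode, so it passes the bare func ("quantize" or "dequantize") and lets the wrapper re-attach the "affine_" prefix; the function name the wrapper derives therefore matches the kname prefix assembled just above. A tiny self-check of that agreement (illustrative values only):

    #include <cassert>
    #include <string>

    int main() {
      std::string mode = "affine", func = "quantize";
      // The wrapper computes fname = mode + "_" + func ...
      std::string fname = mode + "_" + func;
      // ... which matches the prefix used when building kname above.
      assert(fname == "affine_quantize");
      return 0;
    }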