mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
fix METAL quantization in JIT (#2553)
This commit is contained in:
@@ -199,7 +199,7 @@ jobs:
|
||||
name: Run Python tests with JIT
|
||||
command: |
|
||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||
uv pip install -e .
|
||||
uv pip install -e . -v
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
||||
METAL_DEBUG_ERROR_MODE=0 \
|
||||
uv run --no-project python -m xmlrunner discover \
|
||||
@@ -298,7 +298,7 @@ jobs:
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
source ~/miniconda3/bin/activate
|
||||
conda init --all
|
||||
conda create -n env python=<< parameters.python_version >>
|
||||
conda create -n env python=<< parameters.python_version >> -y
|
||||
conda activate env
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.4.0
|
||||
|
@@ -77,7 +77,10 @@ if(MLX_METAL_JIT)
|
||||
make_jit_source(steel/conv/kernels/steel_conv)
|
||||
make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h
|
||||
kernels/steel/conv/loaders/loader_general.h)
|
||||
make_jit_source(quantized)
|
||||
|
||||
make_jit_source(quantized_utils)
|
||||
make_jit_source(quantized kernels/quantized_utils.h)
|
||||
make_jit_source(fp4_quantized kernels/quantized_utils.h)
|
||||
make_jit_source(gemv_masked)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
|
||||
|
@@ -21,7 +21,9 @@ const char* fft();
|
||||
const char* gather_axis();
|
||||
const char* hadamard();
|
||||
const char* logsumexp();
|
||||
const char* quantized_utils();
|
||||
const char* quantized();
|
||||
const char* fp4_quantized();
|
||||
const char* ternary();
|
||||
const char* scan();
|
||||
const char* scatter_axis();
|
||||
|
@@ -804,13 +804,19 @@ MTL::ComputePipelineState* get_fft_kernel(
|
||||
MTL::ComputePipelineState* get_quantized_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& template_def) {
|
||||
const std::string& template_def,
|
||||
const std::string& mode) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gemm() << metal::quantized()
|
||||
<< template_def;
|
||||
return kernel_source.str();
|
||||
std::string kernel_source;
|
||||
concatenate(
|
||||
kernel_source,
|
||||
metal::utils(),
|
||||
metal::gemm(),
|
||||
metal::quantized_utils(),
|
||||
(mode == "affine") ? metal::quantized() : metal::fp4_quantized(),
|
||||
template_def);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
@@ -823,6 +829,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
const array& x,
|
||||
int group_size,
|
||||
int bits,
|
||||
const std::string& mode,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
@@ -833,22 +840,40 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source;
|
||||
concatenate(
|
||||
kernel_source,
|
||||
metal::utils(),
|
||||
metal::gemm(),
|
||||
metal::quantized(),
|
||||
get_template_definition(
|
||||
lib_name,
|
||||
"gather_qmm_rhs",
|
||||
get_type_string(x.dtype()),
|
||||
group_size,
|
||||
bits,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose));
|
||||
kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
|
||||
if (mode == "affine") {
|
||||
concatenate(
|
||||
kernel_source,
|
||||
metal::quantized(),
|
||||
get_template_definition(
|
||||
lib_name,
|
||||
mode + "_gather_qmm_rhs",
|
||||
get_type_string(x.dtype()),
|
||||
group_size,
|
||||
bits,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose));
|
||||
} else {
|
||||
concatenate(
|
||||
kernel_source,
|
||||
metal::fp4_quantized(),
|
||||
get_template_definition(
|
||||
lib_name,
|
||||
mode + "_gather_qmm_rhs",
|
||||
get_type_string(x.dtype()),
|
||||
group_size,
|
||||
"uint8_t",
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose));
|
||||
}
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
|
@@ -238,7 +238,8 @@ MTL::ComputePipelineState* get_fft_kernel(
|
||||
MTL::ComputePipelineState* get_quantized_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& template_def);
|
||||
const std::string& template_def,
|
||||
const std::string& mode);
|
||||
|
||||
MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
metal::Device& d,
|
||||
@@ -248,6 +249,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
const array& x,
|
||||
int group_size,
|
||||
int bits,
|
||||
const std::string& mode,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
|
@@ -270,7 +270,7 @@ struct QuantizedBlockLoader {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int group_size, int D, typename S>
|
||||
template <typename T, int group_size, typename S, int D>
|
||||
METAL_FUNC void mxfp4_qmv_quad_impl(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
@@ -633,8 +633,8 @@ METAL_FUNC void mxfp4_qvm_impl(
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const bool aligned_N,
|
||||
typename S,
|
||||
const bool aligned_N,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
@@ -976,7 +976,7 @@ METAL_FUNC void adjust_matrix_offsets(
|
||||
y += tid.z * output_stride;
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int D, bool batched, typename S>
|
||||
template <typename T, int group_size, typename S, int D, bool batched>
|
||||
[[kernel]] void mxfp4_qmv_quad(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
@@ -1014,7 +1014,7 @@ template <typename T, int group_size, int D, bool batched, typename S>
|
||||
tid);
|
||||
}
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qmv_quad_impl<T, group_size, D>(
|
||||
mxfp4_qmv_quad_impl<T, group_size, S, D>(
|
||||
w,
|
||||
scales,
|
||||
x,
|
||||
@@ -1029,7 +1029,7 @@ template <typename T, int group_size, int D, bool batched, typename S>
|
||||
lut);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, bool batched, typename S>
|
||||
template <typename T, int group_size, typename S, bool batched>
|
||||
[[kernel]] void mxfp4_qmv_fast(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
@@ -1069,7 +1069,7 @@ template <typename T, int group_size, bool batched, typename S>
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, bool batched, typename S>
|
||||
template <typename T, const int group_size, typename S, bool batched>
|
||||
[[kernel]] void mxfp4_qmv(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
@@ -1109,7 +1109,7 @@ template <typename T, const int group_size, bool batched, typename S>
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, bool batched, typename S>
|
||||
template <typename T, const int group_size, typename S, bool batched>
|
||||
[[kernel]] void mxfp4_qvm(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
@@ -1149,7 +1149,7 @@ template <typename T, const int group_size, bool batched, typename S>
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, int split_k = 32, typename S>
|
||||
template <typename T, const int group_size, typename S, int split_k = 32>
|
||||
[[kernel]] void mxfp4_qvm_split_k(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
@@ -1205,9 +1205,9 @@ template <typename T, const int group_size, int split_k = 32, typename S>
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
typename S,
|
||||
const bool aligned_N,
|
||||
const bool batched,
|
||||
typename S,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
@@ -1254,15 +1254,15 @@ template <
|
||||
s_strides,
|
||||
tid);
|
||||
}
|
||||
mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
|
||||
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const bool batched,
|
||||
typename S,
|
||||
const bool batched,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
@@ -1468,8 +1468,8 @@ template <typename T, int group_size, typename S>
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
const bool aligned_N,
|
||||
typename S,
|
||||
const bool aligned_N,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
@@ -1526,7 +1526,7 @@ template <
|
||||
w_strides,
|
||||
s_strides,
|
||||
tid);
|
||||
mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
|
||||
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
|
@@ -20,8 +20,8 @@
|
||||
name, \
|
||||
type, \
|
||||
32, \
|
||||
batched, \
|
||||
uint8_t)
|
||||
uint8_t, \
|
||||
batched)
|
||||
|
||||
#define instantiate_quantized_aligned(name, type, aligned) \
|
||||
instantiate_kernel( \
|
||||
@@ -29,8 +29,8 @@
|
||||
name, \
|
||||
type, \
|
||||
32, \
|
||||
aligned, \
|
||||
uint8_t)
|
||||
uint8_t, \
|
||||
aligned)
|
||||
|
||||
#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
|
||||
instantiate_kernel( \
|
||||
@@ -38,9 +38,9 @@
|
||||
name, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t, \
|
||||
aligned, \
|
||||
batched, \
|
||||
uint8_t)
|
||||
batched)
|
||||
|
||||
#define instantiate_quantized_quad(name, type, D, batched) \
|
||||
instantiate_kernel( \
|
||||
@@ -48,9 +48,9 @@
|
||||
name, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t, \
|
||||
D, \
|
||||
batched, \
|
||||
uint8_t)
|
||||
batched)
|
||||
|
||||
#define instantiate_quantized_split_k(name, type, split_k) \
|
||||
instantiate_kernel( \
|
||||
@@ -58,8 +58,8 @@
|
||||
name, \
|
||||
type, \
|
||||
32, \
|
||||
split_k, \
|
||||
uint8_t)
|
||||
uint8_t, \
|
||||
split_k)
|
||||
|
||||
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
|
||||
instantiate_kernel( \
|
||||
|
@@ -283,6 +283,7 @@ MTL::ComputePipelineState* get_fft_kernel(
|
||||
MTL::ComputePipelineState* get_quantized_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string&,
|
||||
const std::string&) {
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
@@ -295,6 +296,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
const array&,
|
||||
int,
|
||||
int,
|
||||
const std::string&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
|
@@ -15,6 +15,28 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename... Args>
|
||||
auto get_quantized_kernel_wrapped(
|
||||
metal::Device& d,
|
||||
const std::string& name,
|
||||
const std::string& func,
|
||||
const std::string& mode,
|
||||
const std::string& type,
|
||||
int group_size,
|
||||
int bits,
|
||||
Args... args) {
|
||||
std::string template_def;
|
||||
auto fname = mode + "_" + func;
|
||||
if (mode == "affine") {
|
||||
template_def = get_template_definition(
|
||||
name, fname, type, group_size, bits, std::forward<Args>(args)...);
|
||||
} else {
|
||||
template_def = get_template_definition(
|
||||
name, fname, type, group_size, "uint8_t", std::forward<Args>(args)...);
|
||||
}
|
||||
return get_quantized_kernel(d, name, template_def, mode);
|
||||
}
|
||||
|
||||
inline array
|
||||
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||
if (!x.flags().row_contiguous) {
|
||||
@@ -178,10 +200,8 @@ void qmv_quad(
|
||||
"_d_",
|
||||
K,
|
||||
B > 1 ? "_batch_1" : "_batch_0");
|
||||
auto template_def = get_template_definition(
|
||||
kname, mode + "_qmv_quad", type_string, group_size, bits, K, B > 1);
|
||||
|
||||
auto kernel = get_quantized_kernel(d, kname, template_def);
|
||||
auto kernel = get_quantized_kernel_wrapped(
|
||||
d, kname, "qmv_quad", mode, type_string, group_size, bits, K, B > 1);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@@ -235,15 +255,16 @@ void qmv(
|
||||
"_b_",
|
||||
bits,
|
||||
B > 1 ? "_batch_1" : "_batch_0");
|
||||
auto template_def = get_template_definition(
|
||||
auto kernel = get_quantized_kernel_wrapped(
|
||||
d,
|
||||
kname,
|
||||
mode + (fast ? "_qmv_fast" : "_qmv"),
|
||||
(fast ? "qmv_fast" : "qmv"),
|
||||
mode,
|
||||
type_string,
|
||||
group_size,
|
||||
bits,
|
||||
B > 1);
|
||||
|
||||
auto kernel = get_quantized_kernel(d, kname, template_def);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@@ -337,11 +358,11 @@ void qvm_split_k(
|
||||
bits,
|
||||
"_spk_",
|
||||
split_k);
|
||||
auto template_def = get_template_definition(
|
||||
kname, mode + "_qvm_split_k", type_string, group_size, bits, split_k);
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto kernel = get_quantized_kernel(d, kname, template_def);
|
||||
auto kernel = get_quantized_kernel_wrapped(
|
||||
d, kname, "qvm_split_k", mode, type_string, group_size, bits, split_k);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@@ -414,10 +435,8 @@ void qvm(
|
||||
"_b_",
|
||||
bits,
|
||||
B > 1 ? "_batch_1" : "_batch_0");
|
||||
auto template_def = get_template_definition(
|
||||
kname, mode + "_qvm", type_string, group_size, bits, B > 1);
|
||||
|
||||
auto kernel = get_quantized_kernel(d, kname, template_def);
|
||||
auto kernel = get_quantized_kernel_wrapped(
|
||||
d, kname, "qvm", mode, type_string, group_size, bits, B > 1);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@@ -476,21 +495,22 @@ void qmm(
|
||||
transpose ? (aligned ? "_alN_true" : "_alN_false") : "",
|
||||
batched ? "_batch_1" : "_batch_0");
|
||||
std::string template_def;
|
||||
MTL::ComputePipelineState* kernel;
|
||||
if (transpose) {
|
||||
template_def = get_template_definition(
|
||||
kernel = get_quantized_kernel_wrapped(
|
||||
d,
|
||||
kname,
|
||||
mode + "_qmm_t",
|
||||
"qmm_t",
|
||||
mode,
|
||||
type_string,
|
||||
group_size,
|
||||
bits,
|
||||
aligned,
|
||||
batched);
|
||||
} else {
|
||||
template_def = get_template_definition(
|
||||
kname, mode + "_qmm_n", type_string, group_size, bits, batched);
|
||||
kernel = get_quantized_kernel_wrapped(
|
||||
d, kname, "qmm_n", mode, type_string, group_size, bits, batched);
|
||||
}
|
||||
|
||||
auto kernel = get_quantized_kernel(d, kname, template_def);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@@ -539,7 +559,6 @@ void gather_qmm(
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
bool aligned = N % 32 == 0;
|
||||
bool batched = B > 1;
|
||||
std::string type_string = get_type_string(x.dtype());
|
||||
concatenate(
|
||||
kname,
|
||||
@@ -550,16 +569,15 @@ void gather_qmm(
|
||||
"_b_",
|
||||
bits,
|
||||
transpose ? (aligned ? "_alN_true" : "_alN_false") : "");
|
||||
std::string template_def;
|
||||
MTL::ComputePipelineState* kernel;
|
||||
if (transpose) {
|
||||
template_def = get_template_definition(
|
||||
kname, mode + "_gather_qmm_t", type_string, group_size, bits, aligned);
|
||||
kernel = get_quantized_kernel_wrapped(
|
||||
d, kname, "gather_qmm_t", mode, type_string, group_size, bits, aligned);
|
||||
} else {
|
||||
template_def = get_template_definition(
|
||||
kname, mode + "_gather_qmm_n", type_string, group_size, bits);
|
||||
kernel = get_quantized_kernel_wrapped(
|
||||
d, kname, "gather_qmm_n", mode, type_string, group_size, bits);
|
||||
}
|
||||
|
||||
auto kernel = get_quantized_kernel(d, kname, template_def);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@@ -617,14 +635,16 @@ void gather_qmv(
|
||||
group_size,
|
||||
"_b_",
|
||||
bits);
|
||||
auto template_def = get_template_definition(
|
||||
|
||||
auto kernel = get_quantized_kernel_wrapped(
|
||||
d,
|
||||
kname,
|
||||
mode + (fast ? "_gather_qmv_fast" : "_gather_qmv"),
|
||||
(fast ? "gather_qmv_fast" : "gather_qmv"),
|
||||
mode,
|
||||
type_string,
|
||||
group_size,
|
||||
bits);
|
||||
|
||||
auto kernel = get_quantized_kernel(d, kname, template_def);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@@ -680,10 +700,8 @@ void gather_qvm(
|
||||
group_size,
|
||||
"_b_",
|
||||
bits);
|
||||
auto template_def = get_template_definition(
|
||||
kname, mode + "_gather_qvm", type_string, group_size, bits);
|
||||
|
||||
auto kernel = get_quantized_kernel(d, kname, template_def);
|
||||
auto kernel = get_quantized_kernel_wrapped(
|
||||
d, kname, "gather_qvm", mode, type_string, group_size, bits);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@@ -806,6 +824,7 @@ void gather_qmm_rhs(
|
||||
x,
|
||||
group_size,
|
||||
bits,
|
||||
mode,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
@@ -1039,15 +1058,27 @@ void fast::Quantize::eval_gpu(
|
||||
compute_encoder.set_output_array(biases, 3);
|
||||
}
|
||||
|
||||
std::ostringstream kname;
|
||||
auto type_string = dequantize_ ? get_type_string(out.dtype())
|
||||
: get_type_string(w_pre.dtype());
|
||||
auto kernel_func = dequantize_ ? "affine_dequantize" : "affine_quantize";
|
||||
kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_"
|
||||
<< bits_;
|
||||
auto template_def = get_template_definition(
|
||||
kname.str(), kernel_func, type_string, group_size_, bits_);
|
||||
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
|
||||
std::string kname;
|
||||
concatenate(
|
||||
kname,
|
||||
dequantize_ ? "affine_dequantize" : "affine_quantize",
|
||||
"_",
|
||||
type_string,
|
||||
"_gs_",
|
||||
group_size_,
|
||||
"_b_",
|
||||
bits_);
|
||||
auto kernel = get_quantized_kernel_wrapped(
|
||||
d,
|
||||
kname,
|
||||
dequantize_ ? "dequantize" : "quantize",
|
||||
"affine",
|
||||
type_string,
|
||||
group_size_,
|
||||
bits_);
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Treat uint32 as uint8 in kernel
|
||||
|
Reference in New Issue
Block a user