fix METAL quantization in JIT (#2553)
@@ -199,7 +199,7 @@ jobs:
       name: Run Python tests with JIT
       command: |
         CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
-          uv pip install -e .
+          uv pip install -e . -v
         LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
           METAL_DEBUG_ERROR_MODE=0 \
           uv run --no-project python -m xmlrunner discover \
@@ -298,7 +298,7 @@ jobs:
       rm ~/miniconda3/miniconda.sh
       source ~/miniconda3/bin/activate
       conda init --all
-      conda create -n env python=<< parameters.python_version >>
+      conda create -n env python=<< parameters.python_version >> -y
       conda activate env
       pip install --upgrade cmake
       pip install nanobind==2.4.0
@@ -77,7 +77,10 @@ if(MLX_METAL_JIT)
   make_jit_source(steel/conv/kernels/steel_conv)
   make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h
                   kernels/steel/conv/loaders/loader_general.h)
-  make_jit_source(quantized)
+  make_jit_source(quantized_utils)
+  make_jit_source(quantized kernels/quantized_utils.h)
+  make_jit_source(fp4_quantized kernels/quantized_utils.h)
   make_jit_source(gemv_masked)
 else()
   target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
@@ -21,7 +21,9 @@ const char* fft();
 const char* gather_axis();
 const char* hadamard();
 const char* logsumexp();
+const char* quantized_utils();
 const char* quantized();
+const char* fp4_quantized();
 const char* ternary();
 const char* scan();
 const char* scatter_axis();
@@ -804,13 +804,19 @@ MTL::ComputePipelineState* get_fft_kernel(
 MTL::ComputePipelineState* get_quantized_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const std::string& template_def) {
+    const std::string& template_def,
+    const std::string& mode) {
   const auto& lib_name = kernel_name;
   auto lib = d.get_library(lib_name, [&]() {
-    std::ostringstream kernel_source;
-    kernel_source << metal::utils() << metal::gemm() << metal::quantized()
-                  << template_def;
-    return kernel_source.str();
+    std::string kernel_source;
+    concatenate(
+        kernel_source,
+        metal::utils(),
+        metal::gemm(),
+        metal::quantized_utils(),
+        (mode == "affine") ? metal::quantized() : metal::fp4_quantized(),
+        template_def);
+    return kernel_source;
   });
   return d.get_kernel(kernel_name, lib);
 }
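
The JIT path above now assembles the kernel source with MLX's variadic concatenate string helper and selects between the affine and fp4 headers from the mode argument. As a rough, self-contained sketch of that append-in-order pattern (concatenate_sketch and its signature are illustrative stand-ins, not MLX's actual utility):

#include <string>

// Hypothetical stand-in for the variadic string helper used above: append
// every part to the destination buffer in argument order.
template <typename... Parts>
void concatenate_sketch(std::string& dst, const Parts&... parts) {
  (dst.append(parts), ...);
}

// Usage mirroring the call in the hunk: shared headers first, then the
// mode-specific quantization header, then the template definition.
// std::string src;
// concatenate_sketch(src, utils_src, gemm_src, quantized_utils_src,
//                    mode == "affine" ? quantized_src : fp4_quantized_src,
//                    template_def);
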
@@ -823,6 +829,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
     const array& x,
     int group_size,
     int bits,
+    const std::string& mode,
     int bm,
     int bn,
     int bk,
@@ -832,14 +839,15 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
   const auto& lib_name = kernel_name;
   auto lib = d.get_library(lib_name, [&]() {
     std::string kernel_source;
+    concatenate(
+        kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
+    if (mode == "affine") {
       concatenate(
           kernel_source,
-          metal::utils(),
-          metal::gemm(),
           metal::quantized(),
           get_template_definition(
               lib_name,
-              "gather_qmm_rhs",
+              mode + "_gather_qmm_rhs",
               get_type_string(x.dtype()),
               group_size,
               bits,
@@ -849,6 +857,23 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
               wm,
               wn,
               transpose));
+    } else {
+      concatenate(
+          kernel_source,
+          metal::fp4_quantized(),
+          get_template_definition(
+              lib_name,
+              mode + "_gather_qmm_rhs",
+              get_type_string(x.dtype()),
+              group_size,
+              "uint8_t",
+              bm,
+              bn,
+              bk,
+              wm,
+              wn,
+              transpose));
+    }
     return kernel_source;
   });
   return d.get_kernel(kernel_name, lib, hash_name, func_consts);
@@ -238,7 +238,8 @@ MTL::ComputePipelineState* get_fft_kernel(
 MTL::ComputePipelineState* get_quantized_kernel(
     metal::Device& d,
     const std::string& kernel_name,
-    const std::string& template_def);
+    const std::string& template_def,
+    const std::string& mode);
 
 MTL::ComputePipelineState* get_gather_qmm_kernel(
     metal::Device& d,
@@ -248,6 +249,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
     const array& x,
     int group_size,
     int bits,
+    const std::string& mode,
     int bm,
     int bn,
     int bk,
@@ -270,7 +270,7 @@ struct QuantizedBlockLoader {
   }
 };
 
-template <typename T, int group_size, int D, typename S>
+template <typename T, int group_size, typename S, int D>
 METAL_FUNC void mxfp4_qmv_quad_impl(
     const device uint32_t* w,
     const device S* scales,
@@ -633,8 +633,8 @@ METAL_FUNC void mxfp4_qvm_impl(
 template <
     typename T,
     const int group_size,
-    const bool aligned_N,
     typename S,
+    const bool aligned_N,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -976,7 +976,7 @@ METAL_FUNC void adjust_matrix_offsets(
   y += tid.z * output_stride;
 }
 
-template <typename T, int group_size, int D, bool batched, typename S>
+template <typename T, int group_size, typename S, int D, bool batched>
 [[kernel]] void mxfp4_qmv_quad(
     const device uint32_t* w,
     const device S* scales,
@@ -1014,7 +1014,7 @@ template <typename T, int group_size, int D, bool batched, typename S>
         tid);
   }
   threadgroup float lut[16];
-  mxfp4_qmv_quad_impl<T, group_size, D>(
+  mxfp4_qmv_quad_impl<T, group_size, S, D>(
       w,
       scales,
       x,
@@ -1029,7 +1029,7 @@ template <typename T, int group_size, int D, bool batched, typename S>
       lut);
 }
 
-template <typename T, int group_size, bool batched, typename S>
+template <typename T, int group_size, typename S, bool batched>
 [[kernel]] void mxfp4_qmv_fast(
     const device uint32_t* w,
     const device S* scales,
@@ -1069,7 +1069,7 @@ template <typename T, int group_size, bool batched, typename S>
       w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, const int group_size, bool batched, typename S>
+template <typename T, const int group_size, typename S, bool batched>
 [[kernel]] void mxfp4_qmv(
     const device uint32_t* w,
     const device S* scales,
@@ -1109,7 +1109,7 @@ template <typename T, const int group_size, bool batched, typename S>
       w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, const int group_size, bool batched, typename S>
+template <typename T, const int group_size, typename S, bool batched>
 [[kernel]] void mxfp4_qvm(
     const device uint32_t* w,
     const device S* scales,
@@ -1149,7 +1149,7 @@ template <typename T, const int group_size, bool batched, typename S>
      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, const int group_size, int split_k = 32, typename S>
+template <typename T, const int group_size, typename S, int split_k = 32>
 [[kernel]] void mxfp4_qvm_split_k(
     const device uint32_t* w,
     const device S* scales,
@@ -1205,9 +1205,9 @@ template <typename T, const int group_size, int split_k = 32, typename S>
 template <
     typename T,
     const int group_size,
+    typename S,
     const bool aligned_N,
     const bool batched,
-    typename S,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -1254,15 +1254,15 @@ template <
         s_strides,
         tid);
   }
-  mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
+  mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
       w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
 template <
     typename T,
     const int group_size,
-    const bool batched,
     typename S,
+    const bool batched,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -1468,8 +1468,8 @@ template <typename T, int group_size, typename S>
 template <
     typename T,
     const int group_size,
-    const bool aligned_N,
     typename S,
+    const bool aligned_N,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -1526,7 +1526,7 @@ template <
       w_strides,
       s_strides,
       tid);
-  mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
+  mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
@@ -20,8 +20,8 @@
       name, \
       type, \
       32, \
-      batched, \
-      uint8_t)
+      uint8_t, \
+      batched)
 
 #define instantiate_quantized_aligned(name, type, aligned) \
   instantiate_kernel( \
@@ -29,8 +29,8 @@
       name, \
       type, \
       32, \
-      aligned, \
-      uint8_t)
+      uint8_t, \
+      aligned)
 
 #define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
   instantiate_kernel( \
@@ -38,9 +38,9 @@
       name, \
       type, \
       32, \
+      uint8_t, \
       aligned, \
-      batched, \
-      uint8_t)
+      batched)
 
 #define instantiate_quantized_quad(name, type, D, batched) \
   instantiate_kernel( \
@@ -48,9 +48,9 @@
       name, \
       type, \
       32, \
+      uint8_t, \
       D, \
-      batched, \
-      uint8_t)
+      batched)
 
 #define instantiate_quantized_split_k(name, type, split_k) \
   instantiate_kernel( \
@@ -58,8 +58,8 @@
       name, \
       type, \
       32, \
-      split_k, \
-      uint8_t)
+      uint8_t, \
+      split_k)
 
 #define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
   instantiate_kernel( \
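
The macro changes above move the uint8_t scale-type argument so that it follows the group size, mirroring the reordered template parameter lists in the mxfp4 kernels earlier in the diff. Template arguments are positional, so the host-generated argument list and the kernel declaration have to agree; a minimal illustration with hypothetical stand-in templates (not the real Metal kernels):

// Stand-ins for an affine and an mxfp4 kernel template. Both now take their
// "third slot" argument right after the group size: an integer bit width for
// affine, a scale type for mxfp4.
template <typename T, int group_size, int bits, bool batched>
struct affine_qmv_like {};

template <typename T, int group_size, typename S, bool batched>
struct mxfp4_qmv_like {};

// The trailing arguments (here, batched) can then be appended identically for
// both modes, which is the pattern the instantiation macros above and the JIT
// wrapper later in this diff follow.
using affine_ok = affine_qmv_like<float, 32, 4, true>;
using mxfp4_ok = mxfp4_qmv_like<float, 32, unsigned char, true>;
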
@@ -283,6 +283,7 @@ MTL::ComputePipelineState* get_fft_kernel(
 MTL::ComputePipelineState* get_quantized_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string&,
     const std::string&) {
   return d.get_kernel(kernel_name);
 }
@@ -295,6 +296,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
     const array&,
     int,
     int,
+    const std::string&,
     int,
     int,
     int,
@@ -15,6 +15,28 @@ namespace mlx::core {
 
 namespace {
 
+template <typename... Args>
+auto get_quantized_kernel_wrapped(
+    metal::Device& d,
+    const std::string& name,
+    const std::string& func,
+    const std::string& mode,
+    const std::string& type,
+    int group_size,
+    int bits,
+    Args... args) {
+  std::string template_def;
+  auto fname = mode + "_" + func;
+  if (mode == "affine") {
+    template_def = get_template_definition(
+        name, fname, type, group_size, bits, std::forward<Args>(args)...);
+  } else {
+    template_def = get_template_definition(
+        name, fname, type, group_size, "uint8_t", std::forward<Args>(args)...);
+  }
+  return get_quantized_kernel(d, name, template_def, mode);
+}
+
 inline array
 ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
   if (!x.flags().row_contiguous) {
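
The new get_quantized_kernel_wrapped helper centralizes the affine/mxfp4 split: it emits one template-definition string whose third template argument is either the bit width or a uint8_t scale type, then forwards the mode to get_quantized_kernel. A standalone sketch of that branching (the names and output format here are illustrative only, not the strings MLX's get_template_definition actually produces):

#include <iostream>
#include <sstream>
#include <string>

// Join a mixed argument pack into a comma-separated template-argument list.
template <typename... Args>
std::string template_args(const Args&... args) {
  std::ostringstream os;
  bool first = true;
  ((os << (first ? "" : ", ") << args, first = false), ...);
  return os.str();
}

int main() {
  std::string type = "float16_t";
  int group_size = 32, bits = 4;
  for (std::string mode : {"affine", "mxfp4"}) {
    // Affine kernels are parameterized on the bit width; mxfp4 kernels take a
    // scale type instead, which the host side pins to uint8_t.
    std::string args = (mode == "affine")
        ? template_args(type, group_size, bits, /*batched=*/true)
        : template_args(type, group_size, "uint8_t", /*batched=*/true);
    std::cout << mode << "_qmv<" << args << ">\n";
  }
  // Prints:
  //   affine_qmv<float16_t, 32, 4, 1>
  //   mxfp4_qmv<float16_t, 32, uint8_t, 1>
}
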
@@ -178,10 +200,8 @@ void qmv_quad(
       "_d_",
       K,
       B > 1 ? "_batch_1" : "_batch_0");
-  auto template_def = get_template_definition(
-      kname, mode + "_qmv_quad", type_string, group_size, bits, K, B > 1);
-
-  auto kernel = get_quantized_kernel(d, kname, template_def);
+  auto kernel = get_quantized_kernel_wrapped(
+      d, kname, "qmv_quad", mode, type_string, group_size, bits, K, B > 1);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
 
@@ -235,15 +255,16 @@ void qmv(
       "_b_",
       bits,
       B > 1 ? "_batch_1" : "_batch_0");
-  auto template_def = get_template_definition(
+  auto kernel = get_quantized_kernel_wrapped(
+      d,
       kname,
-      mode + (fast ? "_qmv_fast" : "_qmv"),
+      (fast ? "qmv_fast" : "qmv"),
+      mode,
       type_string,
       group_size,
       bits,
       B > 1);
 
-  auto kernel = get_quantized_kernel(d, kname, template_def);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
 
@@ -337,11 +358,11 @@ void qvm_split_k(
       bits,
       "_spk_",
       split_k);
-  auto template_def = get_template_definition(
-      kname, mode + "_qvm_split_k", type_string, group_size, bits, split_k);
 
   // Encode and dispatch kernel
-  auto kernel = get_quantized_kernel(d, kname, template_def);
+  auto kernel = get_quantized_kernel_wrapped(
+      d, kname, "qvm_split_k", mode, type_string, group_size, bits, split_k);
+
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
 
@@ -414,10 +435,8 @@ void qvm(
       "_b_",
       bits,
       B > 1 ? "_batch_1" : "_batch_0");
-  auto template_def = get_template_definition(
-      kname, mode + "_qvm", type_string, group_size, bits, B > 1);
-
-  auto kernel = get_quantized_kernel(d, kname, template_def);
+  auto kernel = get_quantized_kernel_wrapped(
+      d, kname, "qvm", mode, type_string, group_size, bits, B > 1);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
 
@@ -476,21 +495,22 @@ void qmm(
       transpose ? (aligned ? "_alN_true" : "_alN_false") : "",
       batched ? "_batch_1" : "_batch_0");
   std::string template_def;
+  MTL::ComputePipelineState* kernel;
   if (transpose) {
-    template_def = get_template_definition(
+    kernel = get_quantized_kernel_wrapped(
+        d,
         kname,
-        mode + "_qmm_t",
+        "qmm_t",
+        mode,
         type_string,
         group_size,
         bits,
         aligned,
         batched);
   } else {
-    template_def = get_template_definition(
-        kname, mode + "_qmm_n", type_string, group_size, bits, batched);
+    kernel = get_quantized_kernel_wrapped(
+        d, kname, "qmm_n", mode, type_string, group_size, bits, batched);
   }
-
-  auto kernel = get_quantized_kernel(d, kname, template_def);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
 
@@ -539,7 +559,6 @@ void gather_qmm(
   std::string kname;
   kname.reserve(64);
   bool aligned = N % 32 == 0;
-  bool batched = B > 1;
   std::string type_string = get_type_string(x.dtype());
   concatenate(
       kname,
@@ -550,16 +569,15 @@ void gather_qmm(
       "_b_",
       bits,
       transpose ? (aligned ? "_alN_true" : "_alN_false") : "");
-  std::string template_def;
+  MTL::ComputePipelineState* kernel;
   if (transpose) {
-    template_def = get_template_definition(
-        kname, mode + "_gather_qmm_t", type_string, group_size, bits, aligned);
+    kernel = get_quantized_kernel_wrapped(
+        d, kname, "gather_qmm_t", mode, type_string, group_size, bits, aligned);
   } else {
-    template_def = get_template_definition(
-        kname, mode + "_gather_qmm_n", type_string, group_size, bits);
+    kernel = get_quantized_kernel_wrapped(
+        d, kname, "gather_qmm_n", mode, type_string, group_size, bits);
   }
 
-  auto kernel = get_quantized_kernel(d, kname, template_def);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
 
@@ -617,14 +635,16 @@ void gather_qmv(
       group_size,
       "_b_",
       bits);
-  auto template_def = get_template_definition(
+  auto kernel = get_quantized_kernel_wrapped(
+      d,
       kname,
-      mode + (fast ? "_gather_qmv_fast" : "_gather_qmv"),
+      (fast ? "gather_qmv_fast" : "gather_qmv"),
+      mode,
       type_string,
       group_size,
       bits);
 
-  auto kernel = get_quantized_kernel(d, kname, template_def);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
 
@@ -680,10 +700,8 @@ void gather_qvm(
       group_size,
       "_b_",
       bits);
-  auto template_def = get_template_definition(
-      kname, mode + "_gather_qvm", type_string, group_size, bits);
-
-  auto kernel = get_quantized_kernel(d, kname, template_def);
+  auto kernel = get_quantized_kernel_wrapped(
+      d, kname, "gather_qvm", mode, type_string, group_size, bits);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_compute_pipeline_state(kernel);
 
@@ -806,6 +824,7 @@ void gather_qmm_rhs(
       x,
       group_size,
       bits,
+      mode,
       bm,
       bn,
       bk,
@@ -1039,15 +1058,27 @@ void fast::Quantize::eval_gpu(
     compute_encoder.set_output_array(biases, 3);
   }
 
-  std::ostringstream kname;
   auto type_string = dequantize_ ? get_type_string(out.dtype())
                                  : get_type_string(w_pre.dtype());
-  auto kernel_func = dequantize_ ? "affine_dequantize" : "affine_quantize";
-  kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_"
-        << bits_;
-  auto template_def = get_template_definition(
-      kname.str(), kernel_func, type_string, group_size_, bits_);
-  auto kernel = get_quantized_kernel(d, kname.str(), template_def);
+  std::string kname;
+  concatenate(
+      kname,
+      dequantize_ ? "affine_dequantize" : "affine_quantize",
+      "_",
+      type_string,
+      "_gs_",
+      group_size_,
+      "_b_",
+      bits_);
+  auto kernel = get_quantized_kernel_wrapped(
+      d,
+      kname,
+      dequantize_ ? "dequantize" : "quantize",
+      "affine",
+      type_string,
+      group_size_,
+      bits_);
+
   compute_encoder.set_compute_pipeline_state(kernel);
 
   // Treat uint32 as uint8 in kernel