mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix jit
This commit is contained in:
@@ -81,7 +81,8 @@ if(MLX_METAL_JIT)
|
||||
|
||||
make_jit_source(quantized_utils)
|
||||
make_jit_source(quantized kernels/quantized_utils.h)
|
||||
make_jit_source(fp4_quantized kernels/quantized_utils.h)
|
||||
make_jit_source(fp_quantized kernels/quantized_utils.h kernels/fp8.h
|
||||
kernels/fp4.h)
|
||||
make_jit_source(gemv_masked)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
|
||||
|
||||
@@ -24,7 +24,7 @@ const char* hadamard();
|
||||
const char* logsumexp();
|
||||
const char* quantized_utils();
|
||||
const char* quantized();
|
||||
const char* fp4_quantized();
|
||||
const char* fp_quantized();
|
||||
const char* ternary();
|
||||
const char* scan();
|
||||
const char* scatter_axis();
|
||||
|
||||
@@ -829,7 +829,7 @@ MTL::ComputePipelineState* get_quantized_kernel(
|
||||
metal::utils(),
|
||||
metal::gemm(),
|
||||
metal::quantized_utils(),
|
||||
(mode == "affine") ? metal::quantized() : metal::fp4_quantized(),
|
||||
(mode == "affine") ? metal::quantized() : metal::fp_quantized(),
|
||||
template_def);
|
||||
return kernel_source;
|
||||
});
|
||||
@@ -856,39 +856,22 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
std::string kernel_source;
|
||||
concatenate(
|
||||
kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
|
||||
if (mode == "affine") {
|
||||
concatenate(
|
||||
kernel_source,
|
||||
metal::quantized(),
|
||||
get_template_definition(
|
||||
lib_name,
|
||||
mode + "_gather_qmm_rhs",
|
||||
get_type_string(x.dtype()),
|
||||
group_size,
|
||||
bits,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose));
|
||||
} else {
|
||||
concatenate(
|
||||
kernel_source,
|
||||
metal::fp4_quantized(),
|
||||
get_template_definition(
|
||||
lib_name,
|
||||
mode + "_gather_qmm_rhs",
|
||||
get_type_string(x.dtype()),
|
||||
group_size,
|
||||
"uint8_t",
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose));
|
||||
}
|
||||
bool is_affine = mode == "affine";
|
||||
concatenate(
|
||||
kernel_source,
|
||||
is_affine ? metal::quantized() : metal::fp_quantized(),
|
||||
get_template_definition(
|
||||
lib_name,
|
||||
(is_affine ? "affine" : "fp") + std::string("_gather_qmm_rhs"),
|
||||
get_type_string(x.dtype()),
|
||||
group_size,
|
||||
bits,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose));
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
|
||||
@@ -110,7 +110,8 @@ if(NOT MLX_METAL_JIT)
|
||||
reduction/reduce_col.h
|
||||
reduction/reduce_row.h)
|
||||
build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
|
||||
build_kernel(fp_quantized fp_quantized.h quantized_utils.h ${STEEL_HEADERS})
|
||||
build_kernel(fp_quantized fp4.h fp_quantized.h quantized_utils.h
|
||||
${STEEL_HEADERS})
|
||||
build_kernel(scan scan.h)
|
||||
build_kernel(softmax softmax.h)
|
||||
build_kernel(logsumexp logsumexp.h)
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/fp4.h"
|
||||
#include "mlx/backend/metal/kernels/fp8.h"
|
||||
|
||||
constant bool align_M [[function_constant(200)]];
|
||||
constant bool align_N [[function_constant(201)]];
|
||||
constant bool align_K [[function_constant(202)]];
|
||||
@@ -60,7 +63,7 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
|
||||
void load_fp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
|
||||
if (simd_gid == 0 && simd_lid < 16) {
|
||||
lut[simd_lid] = static_cast<T>(FP4_LUT[simd_lid]);
|
||||
}
|
||||
@@ -137,8 +140,7 @@ template <
|
||||
short dst_ld,
|
||||
short reduction_dim,
|
||||
short tgp_size,
|
||||
short group_size,
|
||||
typename S>
|
||||
short group_size>
|
||||
struct QuantizedBlockLoader {
|
||||
static_assert(
|
||||
BCOLS <= group_size,
|
||||
@@ -165,12 +167,12 @@ struct QuantizedBlockLoader {
|
||||
|
||||
threadgroup T* dst;
|
||||
const device uint8_t* src;
|
||||
const device S* scales;
|
||||
const device uint8_t* scales;
|
||||
threadgroup T* lut;
|
||||
|
||||
QuantizedBlockLoader(
|
||||
const device uint8_t* src_,
|
||||
const device S* scales_,
|
||||
const device uint8_t* scales_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
threadgroup T* lut_,
|
||||
@@ -190,7 +192,7 @@ struct QuantizedBlockLoader {
|
||||
bj * bytes_per_pack),
|
||||
scales(scales_ + bi * src_ld / group_size),
|
||||
lut(lut_) {
|
||||
load_mxfp4_lut(lut, simd_group_id, simd_lane_id);
|
||||
load_fp4_lut(lut, simd_group_id, simd_lane_id);
|
||||
}
|
||||
|
||||
void load_unsafe() const {
|
||||
@@ -252,10 +254,10 @@ struct QuantizedBlockLoader {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int group_size, typename S, int D>
|
||||
METAL_FUNC void mxfp4_qmv_quad_impl(
|
||||
template <typename T, int group_size, int bits, int D>
|
||||
METAL_FUNC void fp_qmv_quad_impl(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
constant int& in_vec_size,
|
||||
@@ -277,7 +279,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
|
||||
|
||||
thread U x_thread[values_per_thread];
|
||||
thread U result[results_per_quadgroup] = {0};
|
||||
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||
load_fp4_lut(lut, simd_gid, simd_lid);
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size / pack_factor;
|
||||
@@ -293,7 +295,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
|
||||
|
||||
for (int row = 0; row < results_per_quadgroup; row++) {
|
||||
auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
|
||||
const device S* sl = scales + row * in_vec_size_g * quads_per_simd;
|
||||
const device uint8_t* sl = scales + row * in_vec_size_g * quads_per_simd;
|
||||
|
||||
U s = dequantize_scale<U>(sl[0]);
|
||||
if (row * quads_per_simd + out_row < out_vec_size) {
|
||||
@@ -309,10 +311,10 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int group_size, typename S>
|
||||
METAL_FUNC void mxfp4_qmv_fast_impl(
|
||||
template <typename T, int group_size, int bits>
|
||||
METAL_FUNC void fp_qmv_fast_impl(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
@@ -335,7 +337,7 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
|
||||
typedef float U;
|
||||
thread U x_thread[values_per_thread];
|
||||
thread U result[results_per_simdgroup] = {0};
|
||||
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||
load_fp4_lut(lut, simd_gid, simd_lid);
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
||||
@@ -372,10 +374,10 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int group_size, typename S>
|
||||
METAL_FUNC void mxfp4_qmv_impl(
|
||||
template <typename T, int group_size, int bits>
|
||||
METAL_FUNC void fp_qmv_impl(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
@@ -400,7 +402,7 @@ METAL_FUNC void mxfp4_qmv_impl(
|
||||
|
||||
thread U x_thread[values_per_thread];
|
||||
thread U result[results_per_simdgroup] = {0};
|
||||
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||
load_fp4_lut(lut, simd_gid, simd_lid);
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
||||
@@ -430,7 +432,7 @@ METAL_FUNC void mxfp4_qmv_impl(
|
||||
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
||||
const device auto* sl = scales + row * in_vec_size_g;
|
||||
|
||||
S s = sl[0];
|
||||
uint8_t s = sl[0];
|
||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
|
||||
}
|
||||
|
||||
@@ -511,10 +513,10 @@ METAL_FUNC void mxfp4_qmv_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, typename S>
|
||||
METAL_FUNC void mxfp4_qvm_impl(
|
||||
template <typename T, const int group_size, int bits>
|
||||
METAL_FUNC void fp_qvm_impl(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const int in_vec_size,
|
||||
@@ -543,7 +545,7 @@ METAL_FUNC void mxfp4_qvm_impl(
|
||||
thread U scale = 0;
|
||||
thread U x_local = 0;
|
||||
|
||||
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||
load_fp4_lut(lut, simd_gid, simd_lid);
|
||||
|
||||
// Adjust positions
|
||||
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
|
||||
@@ -615,14 +617,14 @@ METAL_FUNC void mxfp4_qvm_impl(
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
typename S,
|
||||
const int bits,
|
||||
const bool aligned_N,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
METAL_FUNC void mxfp4_qmm_t_impl(
|
||||
METAL_FUNC void fp_qmm_t_impl(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
threadgroup T* Xs,
|
||||
@@ -659,8 +661,7 @@ METAL_FUNC void mxfp4_qmm_t_impl(
|
||||
BK_padded,
|
||||
1,
|
||||
WM * WN * SIMD_SIZE,
|
||||
group_size,
|
||||
S>;
|
||||
group_size>;
|
||||
|
||||
// Set the block
|
||||
const int K_w = K * bytes_per_pack / pack_factor;
|
||||
@@ -741,13 +742,13 @@ METAL_FUNC void mxfp4_qmm_t_impl(
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
typename S,
|
||||
const int bits,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
METAL_FUNC void mxfp4_qmm_n_impl(
|
||||
METAL_FUNC void fp_qmm_n_impl(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
threadgroup T* Xs,
|
||||
@@ -785,8 +786,7 @@ METAL_FUNC void mxfp4_qmm_n_impl(
|
||||
BN_padded,
|
||||
0,
|
||||
WM * WN * SIMD_SIZE,
|
||||
group_size,
|
||||
S>;
|
||||
group_size>;
|
||||
|
||||
auto wl = (const device uint8_t*)w;
|
||||
|
||||
@@ -873,11 +873,11 @@ METAL_FUNC void mxfp4_qmm_n_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
template <typename T>
|
||||
METAL_FUNC void adjust_matrix_offsets(
|
||||
const device T*& x,
|
||||
const device uint32_t*& w,
|
||||
const device S*& scales,
|
||||
const device uint8_t*& scales,
|
||||
device T*& y,
|
||||
int output_stride,
|
||||
const constant int& x_batch_ndims,
|
||||
@@ -908,11 +908,11 @@ METAL_FUNC void adjust_matrix_offsets(
|
||||
y += tid.z * output_stride;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
template <typename T>
|
||||
METAL_FUNC void adjust_matrix_offsets(
|
||||
const device T*& x,
|
||||
const device uint32_t*& w,
|
||||
const device S*& scales,
|
||||
const device uint8_t*& scales,
|
||||
const device uint32_t* lhs_indices,
|
||||
const device uint32_t* rhs_indices,
|
||||
device T*& y,
|
||||
@@ -958,10 +958,10 @@ METAL_FUNC void adjust_matrix_offsets(
|
||||
y += tid.z * output_stride;
|
||||
}
|
||||
|
||||
template <typename T, int group_size, typename S, int D, bool batched>
|
||||
[[kernel]] void mxfp4_qmv_quad(
|
||||
template <typename T, int group_size, int bits, int D, bool batched>
|
||||
[[kernel]] void fp_qmv_quad(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
@@ -996,7 +996,7 @@ template <typename T, int group_size, typename S, int D, bool batched>
|
||||
tid);
|
||||
}
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qmv_quad_impl<T, group_size, S, D>(
|
||||
fp_qmv_quad_impl<T, group_size, bits, D>(
|
||||
w,
|
||||
scales,
|
||||
x,
|
||||
@@ -1011,10 +1011,10 @@ template <typename T, int group_size, typename S, int D, bool batched>
|
||||
lut);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, typename S, bool batched>
|
||||
[[kernel]] void mxfp4_qmv_fast(
|
||||
template <typename T, int group_size, int bits, bool batched>
|
||||
[[kernel]] void fp_qmv_fast(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
@@ -1047,14 +1047,14 @@ template <typename T, int group_size, typename S, bool batched>
|
||||
tid);
|
||||
}
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qmv_fast_impl<T, group_size>(
|
||||
fp_qmv_fast_impl<T, group_size, bits>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, typename S, bool batched>
|
||||
[[kernel]] void mxfp4_qmv(
|
||||
template <typename T, const int group_size, int bits, bool batched>
|
||||
[[kernel]] void fp_qmv(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
@@ -1087,14 +1087,14 @@ template <typename T, const int group_size, typename S, bool batched>
|
||||
tid);
|
||||
}
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qmv_impl<T, group_size>(
|
||||
fp_qmv_impl<T, group_size, bits>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, typename S, bool batched>
|
||||
[[kernel]] void mxfp4_qvm(
|
||||
template <typename T, const int group_size, int bits, bool batched>
|
||||
[[kernel]] void fp_qvm(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
@@ -1127,14 +1127,14 @@ template <typename T, const int group_size, typename S, bool batched>
|
||||
tid);
|
||||
}
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qvm_impl<T, group_size>(
|
||||
fp_qvm_impl<T, group_size, bits>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, typename S, int split_k = 32>
|
||||
[[kernel]] void mxfp4_qvm_split_k(
|
||||
template <typename T, const int group_size, int bits, int split_k = 32>
|
||||
[[kernel]] void fp_qvm_split_k(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& in_vec_size,
|
||||
@@ -1171,7 +1171,7 @@ template <typename T, const int group_size, typename S, int split_k = 32>
|
||||
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
|
||||
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qvm_impl<T, group_size>(
|
||||
fp_qvm_impl<T, group_size, bits>(
|
||||
w,
|
||||
scales,
|
||||
x,
|
||||
@@ -1187,15 +1187,15 @@ template <typename T, const int group_size, typename S, int split_k = 32>
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
typename S,
|
||||
const int bits,
|
||||
const bool aligned_N,
|
||||
const bool batched,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
[[kernel]] void mxfp4_qmm_t(
|
||||
[[kernel]] void fp_qmm_t(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& K,
|
||||
@@ -1236,21 +1236,21 @@ template <
|
||||
s_strides,
|
||||
tid);
|
||||
}
|
||||
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
|
||||
fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
typename S,
|
||||
const int bits,
|
||||
const bool batched,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
[[kernel]] void mxfp4_qmm_n(
|
||||
[[kernel]] void fp_qmm_n(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
device T* y,
|
||||
const constant int& K,
|
||||
@@ -1293,14 +1293,14 @@ template <
|
||||
tid);
|
||||
}
|
||||
|
||||
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
|
||||
fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, typename S>
|
||||
[[kernel]] void mxfp4_gather_qmv_fast(
|
||||
template <typename T, int group_size, int bits>
|
||||
[[kernel]] void fp_gather_qmv_fast(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
const device uint32_t* lhs_indices,
|
||||
const device uint32_t* rhs_indices,
|
||||
@@ -1343,14 +1343,14 @@ template <typename T, int group_size, typename S>
|
||||
s_strides,
|
||||
tid);
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qmv_fast_impl<T, group_size>(
|
||||
fp_qmv_fast_impl<T, group_size, bits>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, typename S>
|
||||
[[kernel]] void mxfp4_gather_qmv(
|
||||
template <typename T, int group_size, int bits>
|
||||
[[kernel]] void fp_gather_qmv(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
const device uint32_t* lhs_indices,
|
||||
const device uint32_t* rhs_indices,
|
||||
@@ -1393,14 +1393,14 @@ template <typename T, int group_size, typename S>
|
||||
s_strides,
|
||||
tid);
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qmv_impl<T, group_size>(
|
||||
fp_qmv_impl<T, group_size, bits>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, typename S>
|
||||
[[kernel]] void mxfp4_gather_qvm(
|
||||
template <typename T, int group_size, int bits>
|
||||
[[kernel]] void fp_gather_qvm(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
const device uint32_t* lhs_indices,
|
||||
const device uint32_t* rhs_indices,
|
||||
@@ -1443,21 +1443,21 @@ template <typename T, int group_size, typename S>
|
||||
s_strides,
|
||||
tid);
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qvm_impl<T, group_size>(
|
||||
fp_qvm_impl<T, group_size, bits>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
typename S,
|
||||
const int bits,
|
||||
const bool aligned_N,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
[[kernel]] void mxfp4_gather_qmm_t(
|
||||
[[kernel]] void fp_gather_qmm_t(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
const device uint32_t* lhs_indices,
|
||||
const device uint32_t* rhs_indices,
|
||||
@@ -1508,20 +1508,20 @@ template <
|
||||
w_strides,
|
||||
s_strides,
|
||||
tid);
|
||||
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
|
||||
fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int group_size,
|
||||
typename S,
|
||||
const int bits,
|
||||
const int BM = 32,
|
||||
const int BK = 32,
|
||||
const int BN = 32>
|
||||
[[kernel]] void mxfp4_gather_qmm_n(
|
||||
[[kernel]] void fp_gather_qmm_n(
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device T* x,
|
||||
const device uint32_t* lhs_indices,
|
||||
const device uint32_t* rhs_indices,
|
||||
@@ -1573,24 +1573,24 @@ template <
|
||||
w_strides,
|
||||
s_strides,
|
||||
tid);
|
||||
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
|
||||
fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int group_size,
|
||||
typename S,
|
||||
int bits,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose>
|
||||
[[kernel]] void mxfp4_gather_qmm_rhs(
|
||||
[[kernel]] void fp_gather_qmm_rhs(
|
||||
const device T* x,
|
||||
const device uint32_t* w,
|
||||
const device S* scales,
|
||||
const device uint8_t* scales,
|
||||
const device uint32_t* indices,
|
||||
device T* y,
|
||||
const constant int& M,
|
||||
@@ -1626,8 +1626,7 @@ template <
|
||||
transpose ? BK_padded : BN_padded,
|
||||
transpose,
|
||||
WM * WN * SIMD_SIZE,
|
||||
group_size,
|
||||
S>;
|
||||
group_size>;
|
||||
|
||||
threadgroup T Xs[BM * BK_padded];
|
||||
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
|
||||
@@ -1775,7 +1774,7 @@ template <
|
||||
template <int bits>
|
||||
struct Quantize {
|
||||
uint8_t operator()(float x) {
|
||||
if constexpr (bits == 8) {
|
||||
if (bits == 8) {
|
||||
return fp8_e4m3(x).bits;
|
||||
} else {
|
||||
return fp4_e2m1(x).bits;
|
||||
@@ -1786,7 +1785,7 @@ struct Quantize {
|
||||
template <int bits>
|
||||
struct Dequantize {
|
||||
float operator()(uint8_t x) {
|
||||
if constexpr (bits == 8) {
|
||||
if (bits == 8) {
|
||||
return float(*(thread fp8_e4m3*)(&x));
|
||||
} else {
|
||||
return float(*(thread fp4_e2m1*)(&x));
|
||||
|
||||
@@ -4,63 +4,61 @@
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/quantized_utils.h"
|
||||
#include "mlx/backend/metal/kernels/fp8.h"
|
||||
#include "mlx/backend/metal/kernels/fp4.h"
|
||||
#include "mlx/backend/metal/kernels/fp_quantized.h"
|
||||
|
||||
#define instantiate_quantized(name, type) \
|
||||
#define instantiate_quantized(mode, name, type) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_32_b_4", \
|
||||
name, \
|
||||
#mode "_" #name "_" #type "_gs_32_b_4", \
|
||||
fp_ ## name, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t)
|
||||
4)
|
||||
|
||||
#define instantiate_quantized_batched(name, type, batched) \
|
||||
#define instantiate_quantized_batched(mode, name, type, batched) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_32_b_4_batch_" #batched, \
|
||||
name, \
|
||||
#mode "_" #name "_" #type "_gs_32_b_4_batch_" #batched, \
|
||||
fp_ ## name, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t, \
|
||||
4, \
|
||||
batched)
|
||||
|
||||
#define instantiate_quantized_aligned(name, type, aligned) \
|
||||
#define instantiate_quantized_aligned(mode, name, type, aligned) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_32_b_4_alN_" #aligned, \
|
||||
name, \
|
||||
#mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned, \
|
||||
fp_ ## name, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t, \
|
||||
4, \
|
||||
aligned)
|
||||
|
||||
#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
|
||||
#define instantiate_quantized_aligned_batched(mode, name, type, aligned, batched) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
|
||||
name, \
|
||||
#mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
|
||||
fp_ ## name, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t, \
|
||||
4, \
|
||||
aligned, \
|
||||
batched)
|
||||
|
||||
#define instantiate_quantized_quad(name, type, D, batched) \
|
||||
#define instantiate_quantized_quad(mode, name, type, D, batched) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
|
||||
name, \
|
||||
#mode "_" #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
|
||||
fp_ ## name, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t, \
|
||||
4, \
|
||||
D, \
|
||||
batched)
|
||||
|
||||
#define instantiate_quantized_split_k(name, type, split_k) \
|
||||
#define instantiate_quantized_split_k(mode, name, type, split_k) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_32_b_4_spk_" #split_k, \
|
||||
name, \
|
||||
#mode "_" #name "_" #type "_gs_32_b_4_spk_" #split_k, \
|
||||
fp_ ## name, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t, \
|
||||
4, \
|
||||
split_k)
|
||||
|
||||
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
|
||||
@@ -69,7 +67,7 @@
|
||||
func, \
|
||||
type, \
|
||||
32, \
|
||||
uint8_t, \
|
||||
4, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
@@ -77,62 +75,62 @@
|
||||
wn, \
|
||||
transpose)
|
||||
|
||||
#define instantiate_quantized_batched_wrap(name, type) \
|
||||
instantiate_quantized_batched(name, type, 1) \
|
||||
instantiate_quantized_batched(name, type, 0)
|
||||
#define instantiate_quantized_batched_wrap(mode, name, type) \
|
||||
instantiate_quantized_batched(mode, name, type, 1) \
|
||||
instantiate_quantized_batched(mode, name, type, 0)
|
||||
|
||||
#define instantiate_quantized_all_batched(type) \
|
||||
instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \
|
||||
instantiate_quantized_batched_wrap(mxfp4_qmv, type) \
|
||||
instantiate_quantized_batched_wrap(mxfp4_qvm, type) \
|
||||
instantiate_quantized_batched_wrap(mxfp4_qmm_n, type)
|
||||
instantiate_quantized_batched_wrap(mxfp4, qmv_fast, type) \
|
||||
instantiate_quantized_batched_wrap(mxfp4, qmv, type) \
|
||||
instantiate_quantized_batched_wrap(mxfp4, qvm, type) \
|
||||
instantiate_quantized_batched_wrap(mxfp4, qmm_n, type)
|
||||
|
||||
#define instantiate_quantized_all_single(type) \
|
||||
instantiate_quantized(mxfp4_gather_qmv_fast, type) \
|
||||
instantiate_quantized(mxfp4_gather_qmv, type) \
|
||||
instantiate_quantized(mxfp4_gather_qvm, type) \
|
||||
instantiate_quantized(mxfp4_gather_qmm_n, type)
|
||||
instantiate_quantized(mxfp4, gather_qmv_fast, type) \
|
||||
instantiate_quantized(mxfp4, gather_qmv, type) \
|
||||
instantiate_quantized(mxfp4, gather_qvm, type) \
|
||||
instantiate_quantized(mxfp4, gather_qmm_n, type)
|
||||
|
||||
#define instantiate_quantized_all_aligned(type) \
|
||||
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \
|
||||
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \
|
||||
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \
|
||||
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \
|
||||
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \
|
||||
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0)
|
||||
instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, true) \
|
||||
instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, false) \
|
||||
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 1) \
|
||||
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 0) \
|
||||
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 1) \
|
||||
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 0)
|
||||
|
||||
#define instantiate_quantized_all_quad(type) \
|
||||
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \
|
||||
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \
|
||||
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \
|
||||
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0)
|
||||
instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 1) \
|
||||
instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 0) \
|
||||
instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 1) \
|
||||
instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 0)
|
||||
|
||||
#define instantiate_quantized_all_splitk(type) \
|
||||
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \
|
||||
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32)
|
||||
instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 8) \
|
||||
instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 32)
|
||||
|
||||
#define instantiate_quantized_all_rhs(type) \
|
||||
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
|
||||
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
|
||||
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
|
||||
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
|
||||
|
||||
#define instantiate_quantize_dequantize(type, mode, group_size, bits) \
|
||||
instantiate_kernel( \
|
||||
mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \
|
||||
#mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \
|
||||
fp_quantize, \
|
||||
type, \
|
||||
group_size, \
|
||||
bits) \
|
||||
instantiate_kernel( \
|
||||
mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \
|
||||
#mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \
|
||||
fp_dequantize, \
|
||||
type, \
|
||||
group_size, \
|
||||
bits)
|
||||
|
||||
#define instantiate_quantize_dequantize_modes(type) \
|
||||
instantiate_quantize_dequantize(type, "mxfp4", 32, 4) \
|
||||
instantiate_quantize_dequantize(type, "nvfp4", 16, 4) \
|
||||
instantiate_quantize_dequantize(type, "mxfp8", 32, 8)
|
||||
instantiate_quantize_dequantize(type, mxfp4, 32, 4) \
|
||||
instantiate_quantize_dequantize(type, nvfp4, 16, 4) \
|
||||
instantiate_quantize_dequantize(type, mxfp8, 32, 8)
|
||||
|
||||
#define instantiate_quantized_types(type) \
|
||||
instantiate_quantized_all_batched(type) \
|
||||
|
||||
@@ -27,14 +27,9 @@ auto get_quantized_kernel_wrapped(
|
||||
int bits,
|
||||
Args... args) {
|
||||
std::string template_def;
|
||||
auto fname = mode + "_" + func;
|
||||
if (mode == "affine") {
|
||||
template_def = get_template_definition(
|
||||
name, fname, type, group_size, bits, std::forward<Args>(args)...);
|
||||
} else {
|
||||
template_def = get_template_definition(
|
||||
name, fname, type, group_size, "uint8_t", std::forward<Args>(args)...);
|
||||
}
|
||||
std::string fname = ((mode == "affine") ? "affine_" : "fp_") + func;
|
||||
template_def = get_template_definition(
|
||||
name, fname, type, group_size, bits, std::forward<Args>(args)...);
|
||||
return get_quantized_kernel(d, name, template_def, mode);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user