Awni Hannun
2025-10-24 10:56:11 -07:00
parent aa44e3c45d
commit 2f3b5c9c5c
7 changed files with 172 additions and 195 deletions

View File

@@ -81,7 +81,8 @@ if(MLX_METAL_JIT)
make_jit_source(quantized_utils)
make_jit_source(quantized kernels/quantized_utils.h)
make_jit_source(fp4_quantized kernels/quantized_utils.h)
make_jit_source(fp_quantized kernels/quantized_utils.h kernels/fp8.h
kernels/fp4.h)
make_jit_source(gemv_masked)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
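Note: make_jit_source is expected to bake each listed header into a generated translation unit that exposes the source as a string. A minimal sketch of the generated file, with the exact codegen shape an assumption:

// Hypothetical shape of the generated jit/fp_quantized.cpp: listing
// kernels/fp8.h and kernels/fp4.h above means their contents get
// inlined into the preamble, so the runtime JIT compile of the
// fp_quantized kernels needs no include paths.
namespace mlx::core::metal {
const char* fp_quantized() {
  return R"preamble(
  // ... quantized_utils.h, fp8.h, fp4.h, then the kernel source ...
  )preamble";
}
} // namespace mlx::core::metal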

View File

@@ -24,7 +24,7 @@ const char* hadamard();
const char* logsumexp();
const char* quantized_utils();
const char* quantized();
const char* fp4_quantized();
const char* fp_quantized();
const char* ternary();
const char* scan();
const char* scatter_axis();

View File

@@ -829,7 +829,7 @@ MTL::ComputePipelineState* get_quantized_kernel(
metal::utils(),
metal::gemm(),
metal::quantized_utils(),
(mode == "affine") ? metal::quantized() : metal::fp4_quantized(),
(mode == "affine") ? metal::quantized() : metal::fp_quantized(),
template_def);
return kernel_source;
});
@@ -856,39 +856,22 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
std::string kernel_source;
concatenate(
kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
if (mode == "affine") {
concatenate(
kernel_source,
metal::quantized(),
get_template_definition(
lib_name,
mode + "_gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
bits,
bm,
bn,
bk,
wm,
wn,
transpose));
} else {
concatenate(
kernel_source,
metal::fp4_quantized(),
get_template_definition(
lib_name,
mode + "_gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
"uint8_t",
bm,
bn,
bk,
wm,
wn,
transpose));
}
bool is_affine = mode == "affine";
concatenate(
kernel_source,
is_affine ? metal::quantized() : metal::fp_quantized(),
get_template_definition(
lib_name,
(is_affine ? "affine" : "fp") + std::string("_gather_qmm_rhs"),
get_type_string(x.dtype()),
group_size,
bits,
bm,
bn,
bk,
wm,
wn,
transpose));
return kernel_source;
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
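With bits an ordinary int parameter in every mode, both branches collapse into one get_template_definition call. The string it appends is expected to be an explicit instantiation along these lines (host name and argument values illustrative):

// Sketch of the appended source for one gather_qmm_rhs variant:
// <half, 32, 4, 16, 32, 32, 1, 2, true> stands for
// <type, group_size, bits, bm, bn, bk, wm, wn, transpose>.
template [[host_name("<lib_name>")]] [[kernel]]
decltype(fp_gather_qmm_rhs<half, 32, 4, 16, 32, 32, 1, 2, true>)
    fp_gather_qmm_rhs<half, 32, 4, 16, 32, 32, 1, 2, true>;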

View File

@@ -110,7 +110,8 @@ if(NOT MLX_METAL_JIT)
reduction/reduce_col.h
reduction/reduce_row.h)
build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
build_kernel(fp_quantized fp_quantized.h quantized_utils.h ${STEEL_HEADERS})
build_kernel(fp_quantized fp4.h fp_quantized.h quantized_utils.h
${STEEL_HEADERS})
build_kernel(scan scan.h)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)

View File

@@ -3,6 +3,9 @@
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/fp4.h"
#include "mlx/backend/metal/kernels/fp8.h"
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
@@ -60,7 +63,7 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
}
template <typename T>
void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
void load_fp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
if (simd_gid == 0 && simd_lid < 16) {
lut[simd_lid] = static_cast<T>(FP4_LUT[simd_lid]);
}
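FP4_LUT itself lives in the newly included fp4.h (not shown in this diff); for E2M1 the sixteen representable values are fixed by the format, so the table is presumably:

// Presumable contents of FP4_LUT for E2M1 (1 sign, 2 exponent,
// 1 mantissa bit): the low 3 bits index the magnitude, the high
// bit flips the sign.
constant float FP4_LUT[16] = {
    0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};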
@@ -137,8 +140,7 @@ template <
short dst_ld,
short reduction_dim,
short tgp_size,
short group_size,
typename S>
short group_size>
struct QuantizedBlockLoader {
static_assert(
BCOLS <= group_size,
@@ -165,12 +167,12 @@ struct QuantizedBlockLoader {
threadgroup T* dst;
const device uint8_t* src;
const device S* scales;
const device uint8_t* scales;
threadgroup T* lut;
QuantizedBlockLoader(
const device uint8_t* src_,
const device S* scales_,
const device uint8_t* scales_,
const int src_ld_,
threadgroup T* dst_,
threadgroup T* lut_,
@@ -190,7 +192,7 @@ struct QuantizedBlockLoader {
bj * bytes_per_pack),
scales(scales_ + bi * src_ld / group_size),
lut(lut_) {
load_mxfp4_lut(lut, simd_group_id, simd_lane_id);
load_fp4_lut(lut, simd_group_id, simd_lane_id);
}
void load_unsafe() const {
@@ -252,10 +254,10 @@ struct QuantizedBlockLoader {
}
};
template <typename T, int group_size, typename S, int D>
METAL_FUNC void mxfp4_qmv_quad_impl(
template <typename T, int group_size, int bits, int D>
METAL_FUNC void fp_qmv_quad_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
constant int& in_vec_size,
@@ -277,7 +279,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
thread U x_thread[values_per_thread];
thread U result[results_per_quadgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid);
load_fp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
@@ -293,7 +295,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
for (int row = 0; row < results_per_quadgroup; row++) {
auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
const device S* sl = scales + row * in_vec_size_g * quads_per_simd;
const device uint8_t* sl = scales + row * in_vec_size_g * quads_per_simd;
U s = dequantize_scale<U>(sl[0]);
if (row * quads_per_simd + out_row < out_vec_size) {
@@ -309,10 +311,10 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
}
}
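dequantize_scale now always consumes a raw byte. A minimal sketch, assuming the MX-style E8M0 encoding (one biased power-of-two exponent per group); the real helper may differ, e.g. to also cover nvfp4's fp8 scales:

// Sketch only: E8M0 scale decode, value = 2^(s - 127).
template <typename U>
inline U dequantize_scale(uint8_t s) {
  return static_cast<U>(metal::exp2(float(s) - 127.0f));
}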
template <typename T, int group_size, typename S>
METAL_FUNC void mxfp4_qmv_fast_impl(
template <typename T, int group_size, int bits>
METAL_FUNC void fp_qmv_fast_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -335,7 +337,7 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid);
load_fp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -372,10 +374,10 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
}
}
template <typename T, int group_size, typename S>
METAL_FUNC void mxfp4_qmv_impl(
template <typename T, int group_size, int bits>
METAL_FUNC void fp_qmv_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -400,7 +402,7 @@ METAL_FUNC void mxfp4_qmv_impl(
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid);
load_fp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -430,7 +432,7 @@ METAL_FUNC void mxfp4_qmv_impl(
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device auto* sl = scales + row * in_vec_size_g;
S s = sl[0];
uint8_t s = sl[0];
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
}
@@ -511,10 +513,10 @@ METAL_FUNC void mxfp4_qmv_impl(
}
}
template <typename T, const int group_size, typename S>
METAL_FUNC void mxfp4_qvm_impl(
template <typename T, const int group_size, int bits>
METAL_FUNC void fp_qvm_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const int in_vec_size,
@@ -543,7 +545,7 @@ METAL_FUNC void mxfp4_qvm_impl(
thread U scale = 0;
thread U x_local = 0;
load_mxfp4_lut(lut, simd_gid, simd_lid);
load_fp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
@@ -615,14 +617,14 @@ METAL_FUNC void mxfp4_qvm_impl(
template <
typename T,
const int group_size,
typename S,
const int bits,
const bool aligned_N,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void mxfp4_qmm_t_impl(
METAL_FUNC void fp_qmm_t_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
threadgroup T* Xs,
@@ -659,8 +661,7 @@ METAL_FUNC void mxfp4_qmm_t_impl(
BK_padded,
1,
WM * WN * SIMD_SIZE,
group_size,
S>;
group_size>;
// Set the block
const int K_w = K * bytes_per_pack / pack_factor;
@@ -741,13 +742,13 @@ METAL_FUNC void mxfp4_qmm_t_impl(
template <
typename T,
const int group_size,
typename S,
const int bits,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void mxfp4_qmm_n_impl(
METAL_FUNC void fp_qmm_n_impl(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
threadgroup T* Xs,
@@ -785,8 +786,7 @@ METAL_FUNC void mxfp4_qmm_n_impl(
BN_padded,
0,
WM * WN * SIMD_SIZE,
group_size,
S>;
group_size>;
auto wl = (const device uint8_t*)w;
@@ -873,11 +873,11 @@ METAL_FUNC void mxfp4_qmm_n_impl(
}
}
template <typename T, typename S>
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device S*& scales,
const device uint8_t*& scales,
device T*& y,
int output_stride,
const constant int& x_batch_ndims,
@@ -908,11 +908,11 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}
template <typename T, typename S>
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device S*& scales,
const device uint8_t*& scales,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
device T*& y,
@@ -958,10 +958,10 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}
template <typename T, int group_size, typename S, int D, bool batched>
[[kernel]] void mxfp4_qmv_quad(
template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void fp_qmv_quad(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -996,7 +996,7 @@ template <typename T, int group_size, typename S, int D, bool batched>
tid);
}
threadgroup float lut[16];
mxfp4_qmv_quad_impl<T, group_size, S, D>(
fp_qmv_quad_impl<T, group_size, bits, D>(
w,
scales,
x,
@@ -1011,10 +1011,10 @@ template <typename T, int group_size, typename S, int D, bool batched>
lut);
}
template <typename T, int group_size, typename S, bool batched>
[[kernel]] void mxfp4_qmv_fast(
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void fp_qmv_fast(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -1047,14 +1047,14 @@ template <typename T, int group_size, typename S, bool batched>
tid);
}
threadgroup float lut[16];
mxfp4_qmv_fast_impl<T, group_size>(
fp_qmv_fast_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <typename T, const int group_size, typename S, bool batched>
[[kernel]] void mxfp4_qmv(
template <typename T, const int group_size, int bits, bool batched>
[[kernel]] void fp_qmv(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -1087,14 +1087,14 @@ template <typename T, const int group_size, typename S, bool batched>
tid);
}
threadgroup float lut[16];
mxfp4_qmv_impl<T, group_size>(
fp_qmv_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <typename T, const int group_size, typename S, bool batched>
[[kernel]] void mxfp4_qvm(
template <typename T, const int group_size, int bits, bool batched>
[[kernel]] void fp_qvm(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -1127,14 +1127,14 @@ template <typename T, const int group_size, typename S, bool batched>
tid);
}
threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>(
fp_qvm_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <typename T, const int group_size, typename S, int split_k = 32>
[[kernel]] void mxfp4_qvm_split_k(
template <typename T, const int group_size, int bits, int split_k = 32>
[[kernel]] void fp_qvm_split_k(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
@@ -1171,7 +1171,7 @@ template <typename T, const int group_size, typename S, int split_k = 32>
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>(
fp_qvm_impl<T, group_size, bits>(
w,
scales,
x,
@@ -1187,15 +1187,15 @@ template <typename T, const int group_size, typename S, int split_k = 32>
template <
typename T,
const int group_size,
typename S,
const int bits,
const bool aligned_N,
const bool batched,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_qmm_t(
[[kernel]] void fp_qmm_t(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& K,
@@ -1236,21 +1236,21 @@ template <
s_strides,
tid);
}
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <
typename T,
const int group_size,
typename S,
const int bits,
const bool batched,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_qmm_n(
[[kernel]] void fp_qmm_n(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
device T* y,
const constant int& K,
@@ -1293,14 +1293,14 @@ template <
tid);
}
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <typename T, int group_size, typename S>
[[kernel]] void mxfp4_gather_qmv_fast(
template <typename T, int group_size, int bits>
[[kernel]] void fp_gather_qmv_fast(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
@@ -1343,14 +1343,14 @@ template <typename T, int group_size, typename S>
s_strides,
tid);
threadgroup float lut[16];
mxfp4_qmv_fast_impl<T, group_size>(
fp_qmv_fast_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <typename T, int group_size, typename S>
[[kernel]] void mxfp4_gather_qmv(
template <typename T, int group_size, int bits>
[[kernel]] void fp_gather_qmv(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
@@ -1393,14 +1393,14 @@ template <typename T, int group_size, typename S>
s_strides,
tid);
threadgroup float lut[16];
mxfp4_qmv_impl<T, group_size>(
fp_qmv_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <typename T, int group_size, typename S>
[[kernel]] void mxfp4_gather_qvm(
template <typename T, int group_size, int bits>
[[kernel]] void fp_gather_qvm(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
@@ -1443,21 +1443,21 @@ template <typename T, int group_size, typename S>
s_strides,
tid);
threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>(
fp_qvm_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <
typename T,
const int group_size,
typename S,
const int bits,
const bool aligned_N,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_gather_qmm_t(
[[kernel]] void fp_gather_qmm_t(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
@@ -1508,20 +1508,20 @@ template <
w_strides,
s_strides,
tid);
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <
typename T,
const int group_size,
typename S,
const int bits,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_gather_qmm_n(
[[kernel]] void fp_gather_qmm_n(
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
@@ -1573,24 +1573,24 @@ template <
w_strides,
s_strides,
tid);
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <
typename T,
int group_size,
typename S,
int bits,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose>
[[kernel]] void mxfp4_gather_qmm_rhs(
[[kernel]] void fp_gather_qmm_rhs(
const device T* x,
const device uint32_t* w,
const device S* scales,
const device uint8_t* scales,
const device uint32_t* indices,
device T* y,
const constant int& M,
@@ -1626,8 +1626,7 @@ template <
transpose ? BK_padded : BN_padded,
transpose,
WM * WN * SIMD_SIZE,
group_size,
S>;
group_size>;
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
@@ -1775,7 +1774,7 @@ template <
template <int bits>
struct Quantize {
uint8_t operator()(float x) {
if constexpr (bits == 8) {
if (bits == 8) {
return fp8_e4m3(x).bits;
} else {
return fp4_e2m1(x).bits;
@@ -1786,7 +1785,7 @@ struct Quantize {
template <int bits>
struct Dequantize {
float operator()(uint8_t x) {
if constexpr (bits == 8) {
if (bits == 8) {
return float(*(thread fp8_e4m3*)(&x));
} else {
return float(*(thread fp4_e2m1*)(&x));
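Quantize and Dequantize select between the two byte formats purely on the bits template argument. As a host-side reference, the e4m3 decode deferred to fp8.h presumably behaves like the following sketch (not the library's implementation):

#include <cmath>
#include <cstdint>

// Reference decode for OCP FP8 E4M3: 1 sign, 4 exponent (bias 7),
// 3 mantissa bits; no infinities; S.1111.111 is the only NaN.
// Max finite value: (1 + 6/8) * 2^8 = 448.
float fp8_e4m3_to_float(uint8_t b) {
  int s = b >> 7;
  int e = (b >> 3) & 0xf;
  int m = b & 0x7;
  float v;
  if (e == 15 && m == 7) {
    v = NAN; // the reserved NaN pattern
  } else if (e == 0) {
    v = m / 512.0f; // subnormal: m * 2^(1-7) / 8
  } else {
    v = std::ldexp(1.0f + m / 8.0f, e - 7);
  }
  return s ? -v : v;
}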

View File

@@ -4,63 +4,61 @@
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/fp8.h"
#include "mlx/backend/metal/kernels/fp4.h"
#include "mlx/backend/metal/kernels/fp_quantized.h"
#define instantiate_quantized(name, type) \
#define instantiate_quantized(mode, name, type) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4", \
name, \
#mode "_" #name "_" #type "_gs_32_b_4", \
fp_ ## name, \
type, \
32, \
uint8_t)
4)
#define instantiate_quantized_batched(name, type, batched) \
#define instantiate_quantized_batched(mode, name, type, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_batch_" #batched, \
name, \
#mode "_" #name "_" #type "_gs_32_b_4_batch_" #batched, \
fp_ ## name, \
type, \
32, \
uint8_t, \
4, \
batched)
#define instantiate_quantized_aligned(name, type, aligned) \
#define instantiate_quantized_aligned(mode, name, type, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned, \
name, \
#mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned, \
fp_ ## name, \
type, \
32, \
uint8_t, \
4, \
aligned)
#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
#define instantiate_quantized_aligned_batched(mode, name, type, aligned, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
name, \
#mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
fp_ ## name, \
type, \
32, \
uint8_t, \
4, \
aligned, \
batched)
#define instantiate_quantized_quad(name, type, D, batched) \
#define instantiate_quantized_quad(mode, name, type, D, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
name, \
#mode "_" #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
fp_ ## name, \
type, \
32, \
uint8_t, \
4, \
D, \
batched)
#define instantiate_quantized_split_k(name, type, split_k) \
#define instantiate_quantized_split_k(mode, name, type, split_k) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_spk_" #split_k, \
name, \
#mode "_" #name "_" #type "_gs_32_b_4_spk_" #split_k, \
fp_ ## name, \
type, \
32, \
uint8_t, \
4, \
split_k)
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
@@ -69,7 +67,7 @@
func, \
type, \
32, \
uint8_t, \
4, \
bm, \
bn, \
bk, \
@@ -77,62 +75,62 @@
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type) \
instantiate_quantized_batched(name, type, 1) \
instantiate_quantized_batched(name, type, 0)
#define instantiate_quantized_batched_wrap(mode, name, type) \
instantiate_quantized_batched(mode, name, type, 1) \
instantiate_quantized_batched(mode, name, type, 0)
#define instantiate_quantized_all_batched(type) \
instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \
instantiate_quantized_batched_wrap(mxfp4_qmv, type) \
instantiate_quantized_batched_wrap(mxfp4_qvm, type) \
instantiate_quantized_batched_wrap(mxfp4_qmm_n, type)
instantiate_quantized_batched_wrap(mxfp4, qmv_fast, type) \
instantiate_quantized_batched_wrap(mxfp4, qmv, type) \
instantiate_quantized_batched_wrap(mxfp4, qvm, type) \
instantiate_quantized_batched_wrap(mxfp4, qmm_n, type)
#define instantiate_quantized_all_single(type) \
instantiate_quantized(mxfp4_gather_qmv_fast, type) \
instantiate_quantized(mxfp4_gather_qmv, type) \
instantiate_quantized(mxfp4_gather_qvm, type) \
instantiate_quantized(mxfp4_gather_qmm_n, type)
instantiate_quantized(mxfp4, gather_qmv_fast, type) \
instantiate_quantized(mxfp4, gather_qmv, type) \
instantiate_quantized(mxfp4, gather_qvm, type) \
instantiate_quantized(mxfp4, gather_qmm_n, type)
#define instantiate_quantized_all_aligned(type) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0)
instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, true) \
instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, false) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 1) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 0) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 1) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 0)
#define instantiate_quantized_all_quad(type) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0)
instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 1) \
instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 0) \
instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 1) \
instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 0)
#define instantiate_quantized_all_splitk(type) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32)
instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 8) \
instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 32)
#define instantiate_quantized_all_rhs(type) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
#define instantiate_quantize_dequantize(type, mode, group_size, bits) \
instantiate_kernel( \
mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \
#mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \
fp_quantize, \
type, \
group_size, \
bits) \
instantiate_kernel( \
mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \
#mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \
fp_dequantize, \
type, \
group_size, \
bits)
#define instantiate_quantize_dequantize_modes(type) \
instantiate_quantize_dequantize(type, "mxfp4", 32, 4) \
instantiate_quantize_dequantize(type, "nvfp4", 16, 4) \
instantiate_quantize_dequantize(type, "mxfp8", 32, 8)
instantiate_quantize_dequantize(type, mxfp4, 32, 4) \
instantiate_quantize_dequantize(type, nvfp4, 16, 4) \
instantiate_quantize_dequantize(type, mxfp8, 32, 8)
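Written out by hand, one invocation now bakes the mode into the host name while routing to the shared fp_ kernel, e.g.:

// instantiate_quantized(mxfp4, gather_qmv_fast, float16_t) expands to:
instantiate_kernel(
    "mxfp4_gather_qmv_fast_float16_t_gs_32_b_4",
    fp_gather_qmv_fast,
    float16_t,
    32,
    4)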
#define instantiate_quantized_types(type) \
instantiate_quantized_all_batched(type) \

View File

@@ -27,14 +27,9 @@ auto get_quantized_kernel_wrapped(
int bits,
Args... args) {
std::string template_def;
auto fname = mode + "_" + func;
if (mode == "affine") {
template_def = get_template_definition(
name, fname, type, group_size, bits, std::forward<Args>(args)...);
} else {
template_def = get_template_definition(
name, fname, type, group_size, "uint8_t", std::forward<Args>(args)...);
}
std::string fname = ((mode == "affine") ? "affine_" : "fp_") + func;
template_def = get_template_definition(
name, fname, type, group_size, bits, std::forward<Args>(args)...);
return get_quantized_kernel(d, name, template_def, mode);
}
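The wrapper's only mode-dependent work is now the kernel-name prefix; for example:

// mode == "affine", func == "qmv_fast" -> fname == "affine_qmv_fast"
// mode == "mxfp4",  func == "qmv_fast" -> fname == "fp_qmv_fast"
// mxfp4 / nvfp4 / mxfp8 all dispatch to the same fp_ kernels,
// distinguished only by the group_size and bits template arguments.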