This commit is contained in:
Awni Hannun
2025-10-24 10:56:11 -07:00
parent aa44e3c45d
commit 2f3b5c9c5c
7 changed files with 172 additions and 195 deletions

View File

@@ -81,7 +81,8 @@ if(MLX_METAL_JIT)
make_jit_source(quantized_utils) make_jit_source(quantized_utils)
make_jit_source(quantized kernels/quantized_utils.h) make_jit_source(quantized kernels/quantized_utils.h)
make_jit_source(fp4_quantized kernels/quantized_utils.h) make_jit_source(fp_quantized kernels/quantized_utils.h kernels/fp8.h
kernels/fp4.h)
make_jit_source(gemv_masked) make_jit_source(gemv_masked)
else() else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)

View File

@@ -24,7 +24,7 @@ const char* hadamard();
const char* logsumexp(); const char* logsumexp();
const char* quantized_utils(); const char* quantized_utils();
const char* quantized(); const char* quantized();
const char* fp4_quantized(); const char* fp_quantized();
const char* ternary(); const char* ternary();
const char* scan(); const char* scan();
const char* scatter_axis(); const char* scatter_axis();

View File

@@ -829,7 +829,7 @@ MTL::ComputePipelineState* get_quantized_kernel(
metal::utils(), metal::utils(),
metal::gemm(), metal::gemm(),
metal::quantized_utils(), metal::quantized_utils(),
(mode == "affine") ? metal::quantized() : metal::fp4_quantized(), (mode == "affine") ? metal::quantized() : metal::fp_quantized(),
template_def); template_def);
return kernel_source; return kernel_source;
}); });
@@ -856,13 +856,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
std::string kernel_source; std::string kernel_source;
concatenate( concatenate(
kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm()); kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
if (mode == "affine") { bool is_affine = mode == "affine";
concatenate( concatenate(
kernel_source, kernel_source,
metal::quantized(), is_affine ? metal::quantized() : metal::fp_quantized(),
get_template_definition( get_template_definition(
lib_name, lib_name,
mode + "_gather_qmm_rhs", (is_affine ? "affine" : "fp") + std::string("_gather_qmm_rhs"),
get_type_string(x.dtype()), get_type_string(x.dtype()),
group_size, group_size,
bits, bits,
@@ -872,23 +872,6 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
wm, wm,
wn, wn,
transpose)); transpose));
} else {
concatenate(
kernel_source,
metal::fp4_quantized(),
get_template_definition(
lib_name,
mode + "_gather_qmm_rhs",
get_type_string(x.dtype()),
group_size,
"uint8_t",
bm,
bn,
bk,
wm,
wn,
transpose));
}
return kernel_source; return kernel_source;
}); });
return d.get_kernel(kernel_name, lib, hash_name, func_consts); return d.get_kernel(kernel_name, lib, hash_name, func_consts);

View File

@@ -110,7 +110,8 @@ if(NOT MLX_METAL_JIT)
reduction/reduce_col.h reduction/reduce_col.h
reduction/reduce_row.h) reduction/reduce_row.h)
build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS}) build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
build_kernel(fp_quantized fp_quantized.h quantized_utils.h ${STEEL_HEADERS}) build_kernel(fp_quantized fp4.h fp_quantized.h quantized_utils.h
${STEEL_HEADERS})
build_kernel(scan scan.h) build_kernel(scan scan.h)
build_kernel(softmax softmax.h) build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h) build_kernel(logsumexp logsumexp.h)

View File

@@ -3,6 +3,9 @@
#include <metal_simdgroup> #include <metal_simdgroup>
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/fp4.h"
#include "mlx/backend/metal/kernels/fp8.h"
constant bool align_M [[function_constant(200)]]; constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]]; constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]]; constant bool align_K [[function_constant(202)]];
@@ -60,7 +63,7 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
} }
template <typename T> template <typename T>
void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) { void load_fp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
if (simd_gid == 0 && simd_lid < 16) { if (simd_gid == 0 && simd_lid < 16) {
lut[simd_lid] = static_cast<T>(FP4_LUT[simd_lid]); lut[simd_lid] = static_cast<T>(FP4_LUT[simd_lid]);
} }
@@ -137,8 +140,7 @@ template <
short dst_ld, short dst_ld,
short reduction_dim, short reduction_dim,
short tgp_size, short tgp_size,
short group_size, short group_size>
typename S>
struct QuantizedBlockLoader { struct QuantizedBlockLoader {
static_assert( static_assert(
BCOLS <= group_size, BCOLS <= group_size,
@@ -165,12 +167,12 @@ struct QuantizedBlockLoader {
threadgroup T* dst; threadgroup T* dst;
const device uint8_t* src; const device uint8_t* src;
const device S* scales; const device uint8_t* scales;
threadgroup T* lut; threadgroup T* lut;
QuantizedBlockLoader( QuantizedBlockLoader(
const device uint8_t* src_, const device uint8_t* src_,
const device S* scales_, const device uint8_t* scales_,
const int src_ld_, const int src_ld_,
threadgroup T* dst_, threadgroup T* dst_,
threadgroup T* lut_, threadgroup T* lut_,
@@ -190,7 +192,7 @@ struct QuantizedBlockLoader {
bj * bytes_per_pack), bj * bytes_per_pack),
scales(scales_ + bi * src_ld / group_size), scales(scales_ + bi * src_ld / group_size),
lut(lut_) { lut(lut_) {
load_mxfp4_lut(lut, simd_group_id, simd_lane_id); load_fp4_lut(lut, simd_group_id, simd_lane_id);
} }
void load_unsafe() const { void load_unsafe() const {
@@ -252,10 +254,10 @@ struct QuantizedBlockLoader {
} }
}; };
template <typename T, int group_size, typename S, int D> template <typename T, int group_size, int bits, int D>
METAL_FUNC void mxfp4_qmv_quad_impl( METAL_FUNC void fp_qmv_quad_impl(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
device T* y, device T* y,
constant int& in_vec_size, constant int& in_vec_size,
@@ -277,7 +279,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
thread U x_thread[values_per_thread]; thread U x_thread[values_per_thread];
thread U result[results_per_quadgroup] = {0}; thread U result[results_per_quadgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid); load_fp4_lut(lut, simd_gid, simd_lid);
// Adjust positions // Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor; const int in_vec_size_w = in_vec_size / pack_factor;
@@ -293,7 +295,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
for (int row = 0; row < results_per_quadgroup; row++) { for (int row = 0; row < results_per_quadgroup; row++) {
auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
const device S* sl = scales + row * in_vec_size_g * quads_per_simd; const device uint8_t* sl = scales + row * in_vec_size_g * quads_per_simd;
U s = dequantize_scale<U>(sl[0]); U s = dequantize_scale<U>(sl[0]);
if (row * quads_per_simd + out_row < out_vec_size) { if (row * quads_per_simd + out_row < out_vec_size) {
@@ -309,10 +311,10 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
} }
} }
template <typename T, int group_size, typename S> template <typename T, int group_size, int bits>
METAL_FUNC void mxfp4_qmv_fast_impl( METAL_FUNC void fp_qmv_fast_impl(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
device T* y, device T* y,
const constant int& in_vec_size, const constant int& in_vec_size,
@@ -335,7 +337,7 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
typedef float U; typedef float U;
thread U x_thread[values_per_thread]; thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0}; thread U result[results_per_simdgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid); load_fp4_lut(lut, simd_gid, simd_lid);
// Adjust positions // Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -372,10 +374,10 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
} }
} }
template <typename T, int group_size, typename S> template <typename T, int group_size, int bits>
METAL_FUNC void mxfp4_qmv_impl( METAL_FUNC void fp_qmv_impl(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
device T* y, device T* y,
const constant int& in_vec_size, const constant int& in_vec_size,
@@ -400,7 +402,7 @@ METAL_FUNC void mxfp4_qmv_impl(
thread U x_thread[values_per_thread]; thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0}; thread U result[results_per_simdgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid); load_fp4_lut(lut, simd_gid, simd_lid);
// Adjust positions // Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -430,7 +432,7 @@ METAL_FUNC void mxfp4_qmv_impl(
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device auto* sl = scales + row * in_vec_size_g; const device auto* sl = scales + row * in_vec_size_g;
S s = sl[0]; uint8_t s = sl[0];
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut); result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
} }
@@ -511,10 +513,10 @@ METAL_FUNC void mxfp4_qmv_impl(
} }
} }
template <typename T, const int group_size, typename S> template <typename T, const int group_size, int bits>
METAL_FUNC void mxfp4_qvm_impl( METAL_FUNC void fp_qvm_impl(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
device T* y, device T* y,
const int in_vec_size, const int in_vec_size,
@@ -543,7 +545,7 @@ METAL_FUNC void mxfp4_qvm_impl(
thread U scale = 0; thread U scale = 0;
thread U x_local = 0; thread U x_local = 0;
load_mxfp4_lut(lut, simd_gid, simd_lid); load_fp4_lut(lut, simd_gid, simd_lid);
// Adjust positions // Adjust positions
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
@@ -615,14 +617,14 @@ METAL_FUNC void mxfp4_qvm_impl(
template < template <
typename T, typename T,
const int group_size, const int group_size,
typename S, const int bits,
const bool aligned_N, const bool aligned_N,
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
METAL_FUNC void mxfp4_qmm_t_impl( METAL_FUNC void fp_qmm_t_impl(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
device T* y, device T* y,
threadgroup T* Xs, threadgroup T* Xs,
@@ -659,8 +661,7 @@ METAL_FUNC void mxfp4_qmm_t_impl(
BK_padded, BK_padded,
1, 1,
WM * WN * SIMD_SIZE, WM * WN * SIMD_SIZE,
group_size, group_size>;
S>;
// Set the block // Set the block
const int K_w = K * bytes_per_pack / pack_factor; const int K_w = K * bytes_per_pack / pack_factor;
@@ -741,13 +742,13 @@ METAL_FUNC void mxfp4_qmm_t_impl(
template < template <
typename T, typename T,
const int group_size, const int group_size,
typename S, const int bits,
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
METAL_FUNC void mxfp4_qmm_n_impl( METAL_FUNC void fp_qmm_n_impl(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
device T* y, device T* y,
threadgroup T* Xs, threadgroup T* Xs,
@@ -785,8 +786,7 @@ METAL_FUNC void mxfp4_qmm_n_impl(
BN_padded, BN_padded,
0, 0,
WM * WN * SIMD_SIZE, WM * WN * SIMD_SIZE,
group_size, group_size>;
S>;
auto wl = (const device uint8_t*)w; auto wl = (const device uint8_t*)w;
@@ -873,11 +873,11 @@ METAL_FUNC void mxfp4_qmm_n_impl(
} }
} }
template <typename T, typename S> template <typename T>
METAL_FUNC void adjust_matrix_offsets( METAL_FUNC void adjust_matrix_offsets(
const device T*& x, const device T*& x,
const device uint32_t*& w, const device uint32_t*& w,
const device S*& scales, const device uint8_t*& scales,
device T*& y, device T*& y,
int output_stride, int output_stride,
const constant int& x_batch_ndims, const constant int& x_batch_ndims,
@@ -908,11 +908,11 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride; y += tid.z * output_stride;
} }
template <typename T, typename S> template <typename T>
METAL_FUNC void adjust_matrix_offsets( METAL_FUNC void adjust_matrix_offsets(
const device T*& x, const device T*& x,
const device uint32_t*& w, const device uint32_t*& w,
const device S*& scales, const device uint8_t*& scales,
const device uint32_t* lhs_indices, const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices, const device uint32_t* rhs_indices,
device T*& y, device T*& y,
@@ -958,10 +958,10 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride; y += tid.z * output_stride;
} }
template <typename T, int group_size, typename S, int D, bool batched> template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void mxfp4_qmv_quad( [[kernel]] void fp_qmv_quad(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
device T* y, device T* y,
const constant int& in_vec_size, const constant int& in_vec_size,
@@ -996,7 +996,7 @@ template <typename T, int group_size, typename S, int D, bool batched>
tid); tid);
} }
threadgroup float lut[16]; threadgroup float lut[16];
mxfp4_qmv_quad_impl<T, group_size, S, D>( fp_qmv_quad_impl<T, group_size, bits, D>(
w, w,
scales, scales,
x, x,
@@ -1011,10 +1011,10 @@ template <typename T, int group_size, typename S, int D, bool batched>
lut); lut);
} }
template <typename T, int group_size, typename S, bool batched> template <typename T, int group_size, int bits, bool batched>
[[kernel]] void mxfp4_qmv_fast( [[kernel]] void fp_qmv_fast(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
device T* y, device T* y,
const constant int& in_vec_size, const constant int& in_vec_size,
@@ -1047,14 +1047,14 @@ template <typename T, int group_size, typename S, bool batched>
tid); tid);
} }
threadgroup float lut[16]; threadgroup float lut[16];
mxfp4_qmv_fast_impl<T, group_size>( fp_qmv_fast_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
} }
template <typename T, const int group_size, typename S, bool batched> template <typename T, const int group_size, int bits, bool batched>
[[kernel]] void mxfp4_qmv( [[kernel]] void fp_qmv(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
device T* y, device T* y,
const constant int& in_vec_size, const constant int& in_vec_size,
@@ -1087,14 +1087,14 @@ template <typename T, const int group_size, typename S, bool batched>
tid); tid);
} }
threadgroup float lut[16]; threadgroup float lut[16];
mxfp4_qmv_impl<T, group_size>( fp_qmv_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
} }
template <typename T, const int group_size, typename S, bool batched> template <typename T, const int group_size, int bits, bool batched>
[[kernel]] void mxfp4_qvm( [[kernel]] void fp_qvm(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
device T* y, device T* y,
const constant int& in_vec_size, const constant int& in_vec_size,
@@ -1127,14 +1127,14 @@ template <typename T, const int group_size, typename S, bool batched>
tid); tid);
} }
threadgroup float lut[16]; threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>( fp_qvm_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
} }
template <typename T, const int group_size, typename S, int split_k = 32> template <typename T, const int group_size, int bits, int split_k = 32>
[[kernel]] void mxfp4_qvm_split_k( [[kernel]] void fp_qvm_split_k(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
device T* y, device T* y,
const constant int& in_vec_size, const constant int& in_vec_size,
@@ -1171,7 +1171,7 @@ template <typename T, const int group_size, typename S, int split_k = 32>
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size; tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
threadgroup float lut[16]; threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>( fp_qvm_impl<T, group_size, bits>(
w, w,
scales, scales,
x, x,
@@ -1187,15 +1187,15 @@ template <typename T, const int group_size, typename S, int split_k = 32>
template < template <
typename T, typename T,
const int group_size, const int group_size,
typename S, const int bits,
const bool aligned_N, const bool aligned_N,
const bool batched, const bool batched,
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
[[kernel]] void mxfp4_qmm_t( [[kernel]] void fp_qmm_t(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
device T* y, device T* y,
const constant int& K, const constant int& K,
@@ -1236,21 +1236,21 @@ template <
s_strides, s_strides,
tid); tid);
} }
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>( fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
} }
template < template <
typename T, typename T,
const int group_size, const int group_size,
typename S, const int bits,
const bool batched, const bool batched,
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
[[kernel]] void mxfp4_qmm_n( [[kernel]] void fp_qmm_n(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
device T* y, device T* y,
const constant int& K, const constant int& K,
@@ -1293,14 +1293,14 @@ template <
tid); tid);
} }
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>( fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
} }
template <typename T, int group_size, typename S> template <typename T, int group_size, int bits>
[[kernel]] void mxfp4_gather_qmv_fast( [[kernel]] void fp_gather_qmv_fast(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
const device uint32_t* lhs_indices, const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices, const device uint32_t* rhs_indices,
@@ -1343,14 +1343,14 @@ template <typename T, int group_size, typename S>
s_strides, s_strides,
tid); tid);
threadgroup float lut[16]; threadgroup float lut[16];
mxfp4_qmv_fast_impl<T, group_size>( fp_qmv_fast_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
} }
template <typename T, int group_size, typename S> template <typename T, int group_size, int bits>
[[kernel]] void mxfp4_gather_qmv( [[kernel]] void fp_gather_qmv(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
const device uint32_t* lhs_indices, const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices, const device uint32_t* rhs_indices,
@@ -1393,14 +1393,14 @@ template <typename T, int group_size, typename S>
s_strides, s_strides,
tid); tid);
threadgroup float lut[16]; threadgroup float lut[16];
mxfp4_qmv_impl<T, group_size>( fp_qmv_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
} }
template <typename T, int group_size, typename S> template <typename T, int group_size, int bits>
[[kernel]] void mxfp4_gather_qvm( [[kernel]] void fp_gather_qvm(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
const device uint32_t* lhs_indices, const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices, const device uint32_t* rhs_indices,
@@ -1443,21 +1443,21 @@ template <typename T, int group_size, typename S>
s_strides, s_strides,
tid); tid);
threadgroup float lut[16]; threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>( fp_qvm_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
} }
template < template <
typename T, typename T,
const int group_size, const int group_size,
typename S, const int bits,
const bool aligned_N, const bool aligned_N,
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
[[kernel]] void mxfp4_gather_qmm_t( [[kernel]] void fp_gather_qmm_t(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
const device uint32_t* lhs_indices, const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices, const device uint32_t* rhs_indices,
@@ -1508,20 +1508,20 @@ template <
w_strides, w_strides,
s_strides, s_strides,
tid); tid);
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>( fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
} }
template < template <
typename T, typename T,
const int group_size, const int group_size,
typename S, const int bits,
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
[[kernel]] void mxfp4_gather_qmm_n( [[kernel]] void fp_gather_qmm_n(
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device T* x, const device T* x,
const device uint32_t* lhs_indices, const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices, const device uint32_t* rhs_indices,
@@ -1573,24 +1573,24 @@ template <
w_strides, w_strides,
s_strides, s_strides,
tid); tid);
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>( fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
} }
template < template <
typename T, typename T,
int group_size, int group_size,
typename S, int bits,
int BM, int BM,
int BN, int BN,
int BK, int BK,
int WM, int WM,
int WN, int WN,
bool transpose> bool transpose>
[[kernel]] void mxfp4_gather_qmm_rhs( [[kernel]] void fp_gather_qmm_rhs(
const device T* x, const device T* x,
const device uint32_t* w, const device uint32_t* w,
const device S* scales, const device uint8_t* scales,
const device uint32_t* indices, const device uint32_t* indices,
device T* y, device T* y,
const constant int& M, const constant int& M,
@@ -1626,8 +1626,7 @@ template <
transpose ? BK_padded : BN_padded, transpose ? BK_padded : BN_padded,
transpose, transpose,
WM * WN * SIMD_SIZE, WM * WN * SIMD_SIZE,
group_size, group_size>;
S>;
threadgroup T Xs[BM * BK_padded]; threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
@@ -1775,7 +1774,7 @@ template <
template <int bits> template <int bits>
struct Quantize { struct Quantize {
uint8_t operator()(float x) { uint8_t operator()(float x) {
if constexpr (bits == 8) { if (bits == 8) {
return fp8_e4m3(x).bits; return fp8_e4m3(x).bits;
} else { } else {
return fp4_e2m1(x).bits; return fp4_e2m1(x).bits;
@@ -1786,7 +1785,7 @@ struct Quantize {
template <int bits> template <int bits>
struct Dequantize { struct Dequantize {
float operator()(uint8_t x) { float operator()(uint8_t x) {
if constexpr (bits == 8) { if (bits == 8) {
return float(*(thread fp8_e4m3*)(&x)); return float(*(thread fp8_e4m3*)(&x));
} else { } else {
return float(*(thread fp4_e2m1*)(&x)); return float(*(thread fp4_e2m1*)(&x));

View File

@@ -4,63 +4,61 @@
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h" #include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/fp8.h"
#include "mlx/backend/metal/kernels/fp4.h"
#include "mlx/backend/metal/kernels/fp_quantized.h" #include "mlx/backend/metal/kernels/fp_quantized.h"
#define instantiate_quantized(name, type) \ #define instantiate_quantized(mode, name, type) \
instantiate_kernel( \ instantiate_kernel( \
#name "_" #type "_gs_32_b_4", \ #mode "_" #name "_" #type "_gs_32_b_4", \
name, \ fp_ ## name, \
type, \ type, \
32, \ 32, \
uint8_t) 4)
#define instantiate_quantized_batched(name, type, batched) \ #define instantiate_quantized_batched(mode, name, type, batched) \
instantiate_kernel( \ instantiate_kernel( \
#name "_" #type "_gs_32_b_4_batch_" #batched, \ #mode "_" #name "_" #type "_gs_32_b_4_batch_" #batched, \
name, \ fp_ ## name, \
type, \ type, \
32, \ 32, \
uint8_t, \ 4, \
batched) batched)
#define instantiate_quantized_aligned(name, type, aligned) \ #define instantiate_quantized_aligned(mode, name, type, aligned) \
instantiate_kernel( \ instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned, \ #mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned, \
name, \ fp_ ## name, \
type, \ type, \
32, \ 32, \
uint8_t, \ 4, \
aligned) aligned)
#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \ #define instantiate_quantized_aligned_batched(mode, name, type, aligned, batched) \
instantiate_kernel( \ instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \ #mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
name, \ fp_ ## name, \
type, \ type, \
32, \ 32, \
uint8_t, \ 4, \
aligned, \ aligned, \
batched) batched)
#define instantiate_quantized_quad(name, type, D, batched) \ #define instantiate_quantized_quad(mode, name, type, D, batched) \
instantiate_kernel( \ instantiate_kernel( \
#name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \ #mode "_" #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
name, \ fp_ ## name, \
type, \ type, \
32, \ 32, \
uint8_t, \ 4, \
D, \ D, \
batched) batched)
#define instantiate_quantized_split_k(name, type, split_k) \ #define instantiate_quantized_split_k(mode, name, type, split_k) \
instantiate_kernel( \ instantiate_kernel( \
#name "_" #type "_gs_32_b_4_spk_" #split_k, \ #mode "_" #name "_" #type "_gs_32_b_4_spk_" #split_k, \
name, \ fp_ ## name, \
type, \ type, \
32, \ 32, \
uint8_t, \ 4, \
split_k) split_k)
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \ #define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
@@ -69,7 +67,7 @@
func, \ func, \
type, \ type, \
32, \ 32, \
uint8_t, \ 4, \
bm, \ bm, \
bn, \ bn, \
bk, \ bk, \
@@ -77,62 +75,62 @@
wn, \ wn, \
transpose) transpose)
#define instantiate_quantized_batched_wrap(name, type) \ #define instantiate_quantized_batched_wrap(mode, name, type) \
instantiate_quantized_batched(name, type, 1) \ instantiate_quantized_batched(mode, name, type, 1) \
instantiate_quantized_batched(name, type, 0) instantiate_quantized_batched(mode, name, type, 0)
#define instantiate_quantized_all_batched(type) \ #define instantiate_quantized_all_batched(type) \
instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \ instantiate_quantized_batched_wrap(mxfp4, qmv_fast, type) \
instantiate_quantized_batched_wrap(mxfp4_qmv, type) \ instantiate_quantized_batched_wrap(mxfp4, qmv, type) \
instantiate_quantized_batched_wrap(mxfp4_qvm, type) \ instantiate_quantized_batched_wrap(mxfp4, qvm, type) \
instantiate_quantized_batched_wrap(mxfp4_qmm_n, type) instantiate_quantized_batched_wrap(mxfp4, qmm_n, type)
#define instantiate_quantized_all_single(type) \ #define instantiate_quantized_all_single(type) \
instantiate_quantized(mxfp4_gather_qmv_fast, type) \ instantiate_quantized(mxfp4, gather_qmv_fast, type) \
instantiate_quantized(mxfp4_gather_qmv, type) \ instantiate_quantized(mxfp4, gather_qmv, type) \
instantiate_quantized(mxfp4_gather_qvm, type) \ instantiate_quantized(mxfp4, gather_qvm, type) \
instantiate_quantized(mxfp4_gather_qmm_n, type) instantiate_quantized(mxfp4, gather_qmm_n, type)
#define instantiate_quantized_all_aligned(type) \ #define instantiate_quantized_all_aligned(type) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \ instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, true) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \ instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, false) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \ instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \ instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 0) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \ instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0) instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 0)
#define instantiate_quantized_all_quad(type) \ #define instantiate_quantized_all_quad(type) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \ instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \ instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 0) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \ instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0) instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 0)
#define instantiate_quantized_all_splitk(type) \ #define instantiate_quantized_all_splitk(type) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \ instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 8) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32) instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 32)
#define instantiate_quantized_all_rhs(type) \ #define instantiate_quantized_all_rhs(type) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \ instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false) instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
#define instantiate_quantize_dequantize(type, mode, group_size, bits) \ #define instantiate_quantize_dequantize(type, mode, group_size, bits) \
instantiate_kernel( \ instantiate_kernel( \
mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \ #mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \
fp_quantize, \ fp_quantize, \
type, \ type, \
group_size, \ group_size, \
bits) \ bits) \
instantiate_kernel( \ instantiate_kernel( \
mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \ #mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \
fp_dequantize, \ fp_dequantize, \
type, \ type, \
group_size, \ group_size, \
bits) bits)
#define instantiate_quantize_dequantize_modes(type) \ #define instantiate_quantize_dequantize_modes(type) \
instantiate_quantize_dequantize(type, "mxfp4", 32, 4) \ instantiate_quantize_dequantize(type, mxfp4, 32, 4) \
instantiate_quantize_dequantize(type, "nvfp4", 16, 4) \ instantiate_quantize_dequantize(type, nvfp4, 16, 4) \
instantiate_quantize_dequantize(type, "mxfp8", 32, 8) instantiate_quantize_dequantize(type, mxfp8, 32, 8)
#define instantiate_quantized_types(type) \ #define instantiate_quantized_types(type) \
instantiate_quantized_all_batched(type) \ instantiate_quantized_all_batched(type) \

View File

@@ -27,14 +27,9 @@ auto get_quantized_kernel_wrapped(
int bits, int bits,
Args... args) { Args... args) {
std::string template_def; std::string template_def;
auto fname = mode + "_" + func; std::string fname = ((mode == "affine") ? "affine_" : "fp_") + func;
if (mode == "affine") {
template_def = get_template_definition( template_def = get_template_definition(
name, fname, type, group_size, bits, std::forward<Args>(args)...); name, fname, type, group_size, bits, std::forward<Args>(args)...);
} else {
template_def = get_template_definition(
name, fname, type, group_size, "uint8_t", std::forward<Args>(args)...);
}
return get_quantized_kernel(d, name, template_def, mode); return get_quantized_kernel(d, name, template_def, mode);
} }