Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Commit: fix jit
@@ -81,7 +81,8 @@ if(MLX_METAL_JIT)
 
   make_jit_source(quantized_utils)
   make_jit_source(quantized kernels/quantized_utils.h)
-  make_jit_source(fp4_quantized kernels/quantized_utils.h)
+  make_jit_source(fp_quantized kernels/quantized_utils.h kernels/fp8.h
+                  kernels/fp4.h)
  make_jit_source(gemv_masked)
 else()
   target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
@@ -24,7 +24,7 @@ const char* hadamard();
 const char* logsumexp();
 const char* quantized_utils();
 const char* quantized();
-const char* fp4_quantized();
+const char* fp_quantized();
 const char* ternary();
 const char* scan();
 const char* scatter_axis();
@@ -829,7 +829,7 @@ MTL::ComputePipelineState* get_quantized_kernel(
         metal::utils(),
         metal::gemm(),
         metal::quantized_utils(),
-        (mode == "affine") ? metal::quantized() : metal::fp4_quantized(),
+        (mode == "affine") ? metal::quantized() : metal::fp_quantized(),
         template_def);
     return kernel_source;
   });
@@ -856,39 +856,22 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
     std::string kernel_source;
     concatenate(
         kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
-    if (mode == "affine") {
+    bool is_affine = mode == "affine";
     concatenate(
         kernel_source,
-        metal::quantized(),
+        is_affine ? metal::quantized() : metal::fp_quantized(),
         get_template_definition(
             lib_name,
-            mode + "_gather_qmm_rhs",
+            (is_affine ? "affine" : "fp") + std::string("_gather_qmm_rhs"),
             get_type_string(x.dtype()),
             group_size,
             bits,
             bm,
             bn,
             bk,
             wm,
             wn,
             transpose));
-    } else {
-      concatenate(
-          kernel_source,
-          metal::fp4_quantized(),
-          get_template_definition(
-              lib_name,
-              mode + "_gather_qmm_rhs",
-              get_type_string(x.dtype()),
-              group_size,
-              "uint8_t",
-              bm,
-              bn,
-              bk,
-              wm,
-              wn,
-              transpose));
-    }
     return kernel_source;
   });
   return d.get_kernel(kernel_name, lib, hash_name, func_consts);
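The hunk above folds the old affine/else branches into one path: the mode now only selects which Metal source gets concatenated and which prefix the gather_qmm_rhs template name receives, while group_size and bits are passed identically in both cases. A minimal sketch of that naming rule, using a hypothetical helper (MLX builds the string inline as shown above):

    #include <string>

    // Hypothetical helper for illustration only.
    std::string gather_qmm_rhs_template(const std::string& mode) {
      bool is_affine = (mode == "affine");
      // The affine path concatenates metal::quantized(); every other mode
      // (mxfp4, nvfp4, mxfp8) concatenates metal::fp_quantized().
      return (is_affine ? "affine" : "fp") + std::string("_gather_qmm_rhs");
    }
    // gather_qmm_rhs_template("mxfp4") == "fp_gather_qmm_rhs"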
@@ -110,7 +110,8 @@ if(NOT MLX_METAL_JIT)
     reduction/reduce_col.h
     reduction/reduce_row.h)
   build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
-  build_kernel(fp_quantized fp_quantized.h quantized_utils.h ${STEEL_HEADERS})
+  build_kernel(fp_quantized fp4.h fp_quantized.h quantized_utils.h
+               ${STEEL_HEADERS})
   build_kernel(scan scan.h)
   build_kernel(softmax softmax.h)
   build_kernel(logsumexp logsumexp.h)
@@ -3,6 +3,9 @@
 #include <metal_simdgroup>
 #include <metal_stdlib>
 
+#include "mlx/backend/metal/kernels/fp4.h"
+#include "mlx/backend/metal/kernels/fp8.h"
+
 constant bool align_M [[function_constant(200)]];
 constant bool align_N [[function_constant(201)]];
 constant bool align_K [[function_constant(202)]];
@@ -60,7 +63,7 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
 }
 
 template <typename T>
-void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
+void load_fp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
   if (simd_gid == 0 && simd_lid < 16) {
     lut[simd_lid] = static_cast<T>(FP4_LUT[simd_lid]);
   }
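In the renamed load_fp4_lut above, the first 16 lanes of simdgroup 0 stage a 16-entry dequantization table (FP4_LUT) into threadgroup memory. As a reference point, a host-side sketch of the table it is assumed to hold, namely the standard FP4 E2M1 value set; kernels/fp4.h is the authoritative source for MLX's actual FP4_LUT.

    #include <cstdint>

    // Assumed E2M1 contents, indexed by the 4-bit code (bit 3 = sign).
    static const float kFp4E2M1[16] = {
        0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,   // sign bit 0
       -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};  // sign bit 1

    // One table read decodes an (unscaled) element; the per-group scale is
    // applied separately by the kernels.
    inline float decode_fp4(uint8_t nibble) {
      return kFp4E2M1[nibble & 0xF];
    }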
@@ -137,8 +140,7 @@ template <
     short dst_ld,
     short reduction_dim,
     short tgp_size,
-    short group_size,
-    typename S>
+    short group_size>
 struct QuantizedBlockLoader {
   static_assert(
       BCOLS <= group_size,
@@ -165,12 +167,12 @@ struct QuantizedBlockLoader {
 
   threadgroup T* dst;
   const device uint8_t* src;
-  const device S* scales;
+  const device uint8_t* scales;
   threadgroup T* lut;
 
   QuantizedBlockLoader(
       const device uint8_t* src_,
-      const device S* scales_,
+      const device uint8_t* scales_,
       const int src_ld_,
       threadgroup T* dst_,
       threadgroup T* lut_,
@@ -190,7 +192,7 @@ struct QuantizedBlockLoader {
            bj * bytes_per_pack),
        scales(scales_ + bi * src_ld / group_size),
        lut(lut_) {
-    load_mxfp4_lut(lut, simd_group_id, simd_lane_id);
+    load_fp4_lut(lut, simd_group_id, simd_lane_id);
   }
 
   void load_unsafe() const {
@@ -252,10 +254,10 @@ struct QuantizedBlockLoader {
   }
 };
 
-template <typename T, int group_size, typename S, int D>
-METAL_FUNC void mxfp4_qmv_quad_impl(
+template <typename T, int group_size, int bits, int D>
+METAL_FUNC void fp_qmv_quad_impl(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     constant int& in_vec_size,
@@ -277,7 +279,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
 
   thread U x_thread[values_per_thread];
   thread U result[results_per_quadgroup] = {0};
-  load_mxfp4_lut(lut, simd_gid, simd_lid);
+  load_fp4_lut(lut, simd_gid, simd_lid);
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size / pack_factor;
@@ -293,7 +295,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
 
   for (int row = 0; row < results_per_quadgroup; row++) {
     auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
-    const device S* sl = scales + row * in_vec_size_g * quads_per_simd;
+    const device uint8_t* sl = scales + row * in_vec_size_g * quads_per_simd;
 
     U s = dequantize_scale<U>(sl[0]);
     if (row * quads_per_simd + out_row < out_vec_size) {
@@ -309,10 +311,10 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
   }
 }
 
-template <typename T, int group_size, typename S>
-METAL_FUNC void mxfp4_qmv_fast_impl(
+template <typename T, int group_size, int bits>
+METAL_FUNC void fp_qmv_fast_impl(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& in_vec_size,
@@ -335,7 +337,7 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
   typedef float U;
   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};
-  load_mxfp4_lut(lut, simd_gid, simd_lid);
+  load_fp4_lut(lut, simd_gid, simd_lid);
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -372,10 +374,10 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
   }
 }
 
-template <typename T, int group_size, typename S>
-METAL_FUNC void mxfp4_qmv_impl(
+template <typename T, int group_size, int bits>
+METAL_FUNC void fp_qmv_impl(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& in_vec_size,
@@ -400,7 +402,7 @@ METAL_FUNC void mxfp4_qmv_impl(
 
   thread U x_thread[values_per_thread];
   thread U result[results_per_simdgroup] = {0};
-  load_mxfp4_lut(lut, simd_gid, simd_lid);
+  load_fp4_lut(lut, simd_gid, simd_lid);
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -430,7 +432,7 @@ METAL_FUNC void mxfp4_qmv_impl(
       auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
       const device auto* sl = scales + row * in_vec_size_g;
 
-      S s = sl[0];
+      uint8_t s = sl[0];
       result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
     }
 
@@ -511,10 +513,10 @@ METAL_FUNC void mxfp4_qmv_impl(
   }
 }
 
-template <typename T, const int group_size, typename S>
-METAL_FUNC void mxfp4_qvm_impl(
+template <typename T, const int group_size, int bits>
+METAL_FUNC void fp_qvm_impl(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const int in_vec_size,
@@ -543,7 +545,7 @@ METAL_FUNC void mxfp4_qvm_impl(
   thread U scale = 0;
   thread U x_local = 0;
 
-  load_mxfp4_lut(lut, simd_gid, simd_lid);
+  load_fp4_lut(lut, simd_gid, simd_lid);
 
   // Adjust positions
   const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
@@ -615,14 +617,14 @@ METAL_FUNC void mxfp4_qvm_impl(
 template <
     typename T,
     const int group_size,
-    typename S,
+    const int bits,
     const bool aligned_N,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
-METAL_FUNC void mxfp4_qmm_t_impl(
+METAL_FUNC void fp_qmm_t_impl(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     threadgroup T* Xs,
@@ -659,8 +661,7 @@ METAL_FUNC void mxfp4_qmm_t_impl(
       BK_padded,
       1,
       WM * WN * SIMD_SIZE,
-      group_size,
-      S>;
+      group_size>;
 
   // Set the block
   const int K_w = K * bytes_per_pack / pack_factor;
@@ -741,13 +742,13 @@ METAL_FUNC void mxfp4_qmm_t_impl(
 template <
     typename T,
     const int group_size,
-    typename S,
+    const int bits,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
-METAL_FUNC void mxfp4_qmm_n_impl(
+METAL_FUNC void fp_qmm_n_impl(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     threadgroup T* Xs,
@@ -785,8 +786,7 @@ METAL_FUNC void mxfp4_qmm_n_impl(
       BN_padded,
       0,
       WM * WN * SIMD_SIZE,
-      group_size,
-      S>;
+      group_size>;
 
   auto wl = (const device uint8_t*)w;
 
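The signature changes in this stretch (typename S becoming int bits, and scales becoming raw uint8_t pointers) reflect that the fp path always stores one scale byte per group and decodes it on the fly via dequantize_scale. A hedged sketch of that decode for the MX modes (mxfp4/mxfp8), whose shared scale is an E8M0 byte, i.e. a pure power of two with bias 127; nvfp4 uses an FP8 scale instead, and kernels/fp8.h holds MLX's real implementation.

    #include <cmath>
    #include <cstdint>

    // Sketch only: decode an MX-style (E8M0) shared scale byte.
    inline float decode_e8m0_scale(uint8_t s) {
      return std::ldexp(1.0f, int(s) - 127);  // 2^(s - 127)
    }
    // A dequantized element is then lut[code] * decode_e8m0_scale(scale_byte),
    // which is the product qdot accumulates group by group.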
@@ -873,11 +873,11 @@ METAL_FUNC void mxfp4_qmm_n_impl(
   }
 }
 
-template <typename T, typename S>
+template <typename T>
 METAL_FUNC void adjust_matrix_offsets(
     const device T*& x,
     const device uint32_t*& w,
-    const device S*& scales,
+    const device uint8_t*& scales,
     device T*& y,
     int output_stride,
     const constant int& x_batch_ndims,
@@ -908,11 +908,11 @@ METAL_FUNC void adjust_matrix_offsets(
   y += tid.z * output_stride;
 }
 
-template <typename T, typename S>
+template <typename T>
 METAL_FUNC void adjust_matrix_offsets(
     const device T*& x,
     const device uint32_t*& w,
-    const device S*& scales,
+    const device uint8_t*& scales,
     const device uint32_t* lhs_indices,
     const device uint32_t* rhs_indices,
     device T*& y,
@@ -958,10 +958,10 @@ METAL_FUNC void adjust_matrix_offsets(
   y += tid.z * output_stride;
 }
 
-template <typename T, int group_size, typename S, int D, bool batched>
-[[kernel]] void mxfp4_qmv_quad(
+template <typename T, int group_size, int bits, int D, bool batched>
+[[kernel]] void fp_qmv_quad(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& in_vec_size,
@@ -996,7 +996,7 @@ template <typename T, int group_size, typename S, int D, bool batched>
         tid);
   }
   threadgroup float lut[16];
-  mxfp4_qmv_quad_impl<T, group_size, S, D>(
+  fp_qmv_quad_impl<T, group_size, bits, D>(
       w,
       scales,
       x,
@@ -1011,10 +1011,10 @@ template <typename T, int group_size, typename S, int D, bool batched>
       lut);
 }
 
-template <typename T, int group_size, typename S, bool batched>
-[[kernel]] void mxfp4_qmv_fast(
+template <typename T, int group_size, int bits, bool batched>
+[[kernel]] void fp_qmv_fast(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& in_vec_size,
@@ -1047,14 +1047,14 @@ template <typename T, int group_size, typename S, bool batched>
        tid);
   }
   threadgroup float lut[16];
-  mxfp4_qmv_fast_impl<T, group_size>(
+  fp_qmv_fast_impl<T, group_size, bits>(
       w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, const int group_size, typename S, bool batched>
-[[kernel]] void mxfp4_qmv(
+template <typename T, const int group_size, int bits, bool batched>
+[[kernel]] void fp_qmv(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& in_vec_size,
@@ -1087,14 +1087,14 @@ template <typename T, const int group_size, typename S, bool batched>
        tid);
   }
   threadgroup float lut[16];
-  mxfp4_qmv_impl<T, group_size>(
+  fp_qmv_impl<T, group_size, bits>(
       w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, const int group_size, typename S, bool batched>
-[[kernel]] void mxfp4_qvm(
+template <typename T, const int group_size, int bits, bool batched>
+[[kernel]] void fp_qvm(
    const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& in_vec_size,
@@ -1127,14 +1127,14 @@ template <typename T, const int group_size, typename S, bool batched>
        tid);
   }
   threadgroup float lut[16];
-  mxfp4_qvm_impl<T, group_size>(
+  fp_qvm_impl<T, group_size, bits>(
       w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, const int group_size, typename S, int split_k = 32>
-[[kernel]] void mxfp4_qvm_split_k(
+template <typename T, const int group_size, int bits, int split_k = 32>
+[[kernel]] void fp_qvm_split_k(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& in_vec_size,
@@ -1171,7 +1171,7 @@ template <typename T, const int group_size, typename S, int split_k = 32>
       tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
 
   threadgroup float lut[16];
-  mxfp4_qvm_impl<T, group_size>(
+  fp_qvm_impl<T, group_size, bits>(
       w,
       scales,
       x,
@@ -1187,15 +1187,15 @@ template <typename T, const int group_size, typename S, int split_k = 32>
 template <
     typename T,
     const int group_size,
-    typename S,
+    const int bits,
     const bool aligned_N,
     const bool batched,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
-[[kernel]] void mxfp4_qmm_t(
+[[kernel]] void fp_qmm_t(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& K,
@@ -1236,21 +1236,21 @@ template <
         s_strides,
         tid);
   }
-  mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
+  fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
       w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
 template <
     typename T,
     const int group_size,
-    typename S,
+    const int bits,
     const bool batched,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
-[[kernel]] void mxfp4_qmm_n(
+[[kernel]] void fp_qmm_n(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     device T* y,
     const constant int& K,
@@ -1293,14 +1293,14 @@ template <
        tid);
   }
 
-  mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
+  fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, int group_size, typename S>
-[[kernel]] void mxfp4_gather_qmv_fast(
+template <typename T, int group_size, int bits>
+[[kernel]] void fp_gather_qmv_fast(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     const device uint32_t* lhs_indices,
     const device uint32_t* rhs_indices,
@@ -1343,14 +1343,14 @@ template <typename T, int group_size, typename S>
       s_strides,
       tid);
   threadgroup float lut[16];
-  mxfp4_qmv_fast_impl<T, group_size>(
+  fp_qmv_fast_impl<T, group_size, bits>(
      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, int group_size, typename S>
-[[kernel]] void mxfp4_gather_qmv(
+template <typename T, int group_size, int bits>
+[[kernel]] void fp_gather_qmv(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     const device uint32_t* lhs_indices,
     const device uint32_t* rhs_indices,
@@ -1393,14 +1393,14 @@ template <typename T, int group_size, typename S>
       s_strides,
       tid);
   threadgroup float lut[16];
-  mxfp4_qmv_impl<T, group_size>(
+  fp_qmv_impl<T, group_size, bits>(
      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
-template <typename T, int group_size, typename S>
-[[kernel]] void mxfp4_gather_qvm(
+template <typename T, int group_size, int bits>
+[[kernel]] void fp_gather_qvm(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     const device uint32_t* lhs_indices,
     const device uint32_t* rhs_indices,
@@ -1443,21 +1443,21 @@ template <typename T, int group_size, typename S>
       s_strides,
       tid);
   threadgroup float lut[16];
-  mxfp4_qvm_impl<T, group_size>(
+  fp_qvm_impl<T, group_size, bits>(
      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
 }
 
 template <
     typename T,
     const int group_size,
-    typename S,
+    const int bits,
     const bool aligned_N,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
-[[kernel]] void mxfp4_gather_qmm_t(
+[[kernel]] void fp_gather_qmm_t(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     const device uint32_t* lhs_indices,
     const device uint32_t* rhs_indices,
@@ -1508,20 +1508,20 @@ template <
       w_strides,
       s_strides,
       tid);
-  mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
+  fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
 template <
     typename T,
     const int group_size,
-    typename S,
+    const int bits,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
-[[kernel]] void mxfp4_gather_qmm_n(
+[[kernel]] void fp_gather_qmm_n(
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device T* x,
     const device uint32_t* lhs_indices,
     const device uint32_t* rhs_indices,
@@ -1573,24 +1573,24 @@ template <
       w_strides,
       s_strides,
       tid);
-  mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
+  fp_qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
 }
 
 template <
     typename T,
     int group_size,
-    typename S,
+    int bits,
     int BM,
     int BN,
     int BK,
     int WM,
     int WN,
     bool transpose>
-[[kernel]] void mxfp4_gather_qmm_rhs(
+[[kernel]] void fp_gather_qmm_rhs(
     const device T* x,
     const device uint32_t* w,
-    const device S* scales,
+    const device uint8_t* scales,
     const device uint32_t* indices,
     device T* y,
     const constant int& M,
@@ -1626,8 +1626,7 @@ template <
       transpose ? BK_padded : BN_padded,
       transpose,
       WM * WN * SIMD_SIZE,
-      group_size,
-      S>;
+      group_size>;
 
   threadgroup T Xs[BM * BK_padded];
   threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
@@ -1775,7 +1774,7 @@ template <
 template <int bits>
 struct Quantize {
   uint8_t operator()(float x) {
-    if constexpr (bits == 8) {
+    if (bits == 8) {
       return fp8_e4m3(x).bits;
     } else {
       return fp4_e2m1(x).bits;
@@ -1786,7 +1785,7 @@ struct Quantize {
 template <int bits>
 struct Dequantize {
   float operator()(uint8_t x) {
-    if constexpr (bits == 8) {
+    if (bits == 8) {
       return float(*(thread fp8_e4m3*)(&x));
     } else {
       return float(*(thread fp4_e2m1*)(&x));
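The Quantize/Dequantize functors above dispatch on bits: 8 maps to fp8_e4m3 and everything else to fp4_e2m1. For orientation, a hedged scalar decoder for the OCP FP8 E4M3 layout (bias 7, no infinities, mantissa-all-ones at the top exponent is NaN), which MLX's fp8_e4m3 type is assumed to follow; kernels/fp8.h has the real definition.

    #include <cmath>
    #include <cstdint>

    // Sketch: decode one FP8 E4M3 byte (sign:1, exponent:4, mantissa:3).
    inline float decode_fp8_e4m3(uint8_t b) {
      const int sign = (b >> 7) & 0x1;
      const int exp = (b >> 3) & 0xF;
      const int man = b & 0x7;
      float mag;
      if (exp == 0) {
        mag = std::ldexp(man / 8.0f, -6);  // subnormals, including +/-0
      } else if (exp == 15 && man == 7) {
        return NAN;  // E4M3 spends this code on NaN instead of infinity
      } else {
        mag = std::ldexp(1.0f + man / 8.0f, exp - 7);  // normals, up to 448
      }
      return sign ? -mag : mag;
    }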
@@ -4,63 +4,61 @@
 #include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
 #include "mlx/backend/metal/kernels/quantized_utils.h"
-#include "mlx/backend/metal/kernels/fp8.h"
-#include "mlx/backend/metal/kernels/fp4.h"
 #include "mlx/backend/metal/kernels/fp_quantized.h"
 
-#define instantiate_quantized(name, type) \
+#define instantiate_quantized(mode, name, type) \
   instantiate_kernel( \
-      #name "_" #type "_gs_32_b_4", \
-      name, \
+      #mode "_" #name "_" #type "_gs_32_b_4", \
+      fp_ ## name, \
       type, \
       32, \
-      uint8_t)
+      4)
 
-#define instantiate_quantized_batched(name, type, batched) \
+#define instantiate_quantized_batched(mode, name, type, batched) \
   instantiate_kernel( \
-      #name "_" #type "_gs_32_b_4_batch_" #batched, \
-      name, \
+      #mode "_" #name "_" #type "_gs_32_b_4_batch_" #batched, \
+      fp_ ## name, \
       type, \
       32, \
-      uint8_t, \
+      4, \
       batched)
 
-#define instantiate_quantized_aligned(name, type, aligned) \
+#define instantiate_quantized_aligned(mode, name, type, aligned) \
   instantiate_kernel( \
-      #name "_" #type "_gs_32_b_4_alN_" #aligned, \
-      name, \
+      #mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned, \
+      fp_ ## name, \
      type, \
       32, \
-      uint8_t, \
+      4, \
       aligned)
 
-#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
+#define instantiate_quantized_aligned_batched(mode, name, type, aligned, batched) \
   instantiate_kernel( \
-      #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
-      name, \
+      #mode "_" #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
+      fp_ ## name, \
       type, \
       32, \
-      uint8_t, \
+      4, \
       aligned, \
       batched)
 
-#define instantiate_quantized_quad(name, type, D, batched) \
+#define instantiate_quantized_quad(mode, name, type, D, batched) \
   instantiate_kernel( \
-      #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
-      name, \
+      #mode "_" #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
+      fp_ ## name, \
       type, \
       32, \
-      uint8_t, \
+      4, \
       D, \
       batched)
 
-#define instantiate_quantized_split_k(name, type, split_k) \
+#define instantiate_quantized_split_k(mode, name, type, split_k) \
   instantiate_kernel( \
-      #name "_" #type "_gs_32_b_4_spk_" #split_k, \
-      name, \
+      #mode "_" #name "_" #type "_gs_32_b_4_spk_" #split_k, \
+      fp_ ## name, \
       type, \
       32, \
-      uint8_t, \
+      4, \
       split_k)
 
 #define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
@@ -69,7 +67,7 @@
       func, \
       type, \
       32, \
-      uint8_t, \
+      4, \
       bm, \
       bn, \
       bk, \
@@ -77,62 +75,62 @@
       wn, \
       transpose)
 
-#define instantiate_quantized_batched_wrap(name, type) \
-  instantiate_quantized_batched(name, type, 1) \
-  instantiate_quantized_batched(name, type, 0)
+#define instantiate_quantized_batched_wrap(mode, name, type) \
+  instantiate_quantized_batched(mode, name, type, 1) \
+  instantiate_quantized_batched(mode, name, type, 0)
 
 #define instantiate_quantized_all_batched(type) \
-  instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \
-  instantiate_quantized_batched_wrap(mxfp4_qmv, type) \
-  instantiate_quantized_batched_wrap(mxfp4_qvm, type) \
-  instantiate_quantized_batched_wrap(mxfp4_qmm_n, type)
+  instantiate_quantized_batched_wrap(mxfp4, qmv_fast, type) \
+  instantiate_quantized_batched_wrap(mxfp4, qmv, type) \
+  instantiate_quantized_batched_wrap(mxfp4, qvm, type) \
+  instantiate_quantized_batched_wrap(mxfp4, qmm_n, type)
 
 #define instantiate_quantized_all_single(type) \
-  instantiate_quantized(mxfp4_gather_qmv_fast, type) \
-  instantiate_quantized(mxfp4_gather_qmv, type) \
-  instantiate_quantized(mxfp4_gather_qvm, type) \
-  instantiate_quantized(mxfp4_gather_qmm_n, type)
+  instantiate_quantized(mxfp4, gather_qmv_fast, type) \
+  instantiate_quantized(mxfp4, gather_qmv, type) \
+  instantiate_quantized(mxfp4, gather_qvm, type) \
+  instantiate_quantized(mxfp4, gather_qmm_n, type)
 
 #define instantiate_quantized_all_aligned(type) \
-  instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \
-  instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \
-  instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \
-  instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \
-  instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \
-  instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0)
+  instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, true) \
+  instantiate_quantized_aligned(mxfp4, gather_qmm_t, type, false) \
+  instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 1) \
+  instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, true, 0) \
+  instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 1) \
+  instantiate_quantized_aligned_batched(mxfp4, qmm_t, type, false, 0)
 
 #define instantiate_quantized_all_quad(type) \
-  instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \
-  instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \
-  instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \
-  instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0)
+  instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 1) \
+  instantiate_quantized_quad(mxfp4, qmv_quad, type, 64, 0) \
+  instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 1) \
+  instantiate_quantized_quad(mxfp4, qmv_quad, type, 128, 0)
 
 #define instantiate_quantized_all_splitk(type) \
-  instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \
-  instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32)
+  instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 8) \
+  instantiate_quantized_split_k(mxfp4, qvm_split_k, type, 32)
 
 #define instantiate_quantized_all_rhs(type) \
-  instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
-  instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
+  instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
+  instantiate_gather_qmm_rhs(fp_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
 
 #define instantiate_quantize_dequantize(type, mode, group_size, bits) \
   instantiate_kernel( \
-      mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \
+      #mode "_quantize_" #type "_gs_" #group_size "_b_" #bits, \
       fp_quantize, \
       type, \
       group_size, \
       bits) \
   instantiate_kernel( \
-      mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \
+      #mode "_dequantize_" #type "_gs_" #group_size "_b_" #bits, \
       fp_dequantize, \
       type, \
       group_size, \
       bits)
 
 #define instantiate_quantize_dequantize_modes(type) \
-  instantiate_quantize_dequantize(type, "mxfp4", 32, 4) \
-  instantiate_quantize_dequantize(type, "nvfp4", 16, 4) \
-  instantiate_quantize_dequantize(type, "mxfp8", 32, 8)
+  instantiate_quantize_dequantize(type, mxfp4, 32, 4) \
+  instantiate_quantize_dequantize(type, nvfp4, 16, 4) \
+  instantiate_quantize_dequantize(type, mxfp8, 32, 8)
 
 #define instantiate_quantized_types(type) \
   instantiate_quantized_all_batched(type) \
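The macro changes above split the old fused kernel name into a (mode, name) pair: the stringized mode and name build the host-visible kernel string, while token pasting always resolves to the shared fp_* template with group size 32 and 4 bits. instantiate_kernel itself lives elsewhere (kernels/utils.h) and is not reproduced here; the snippet below only demonstrates the same preprocessor mechanics with stand-in names.

    #include <cstdio>

    #define MAKE_NAME(mode, name, type) #mode "_" #name "_" #type "_gs_32_b_4"
    #define PASTE_FP(name) fp_##name

    void fp_qmv_fast(int) {}  // stand-in for the real kernel template

    int main() {
      // Prints "mxfp4_qmv_fast_float16_t_gs_32_b_4", the host-visible name.
      std::puts(MAKE_NAME(mxfp4, qmv_fast, float16_t));
      PASTE_FP(qmv_fast)(0);  // token-pastes to fp_qmv_fast
    }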
@@ -27,14 +27,9 @@ auto get_quantized_kernel_wrapped(
     int bits,
     Args... args) {
   std::string template_def;
-  auto fname = mode + "_" + func;
-  if (mode == "affine") {
-    template_def = get_template_definition(
-        name, fname, type, group_size, bits, std::forward<Args>(args)...);
-  } else {
-    template_def = get_template_definition(
-        name, fname, type, group_size, "uint8_t", std::forward<Args>(args)...);
-  }
+  std::string fname = ((mode == "affine") ? "affine_" : "fp_") + func;
+  template_def = get_template_definition(
+      name, fname, type, group_size, bits, std::forward<Args>(args)...);
   return get_quantized_kernel(d, name, template_def, mode);
 }
 
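With the wrapper simplified as above, every non-affine mode funnels into the same fp_* templates and bits is an ordinary integer argument for both paths. A hedged sketch of the resulting mode table, read off the instantiations in this commit rather than from a separate MLX API:

    #include <string>

    struct ModeInfo {
      std::string prefix;  // prepended to the kernel function name
      int group_size;
      int bits;
    };

    // Sketch only; "affine" supports several group sizes and bit widths,
    // so the zeros below just mean "chosen per call".
    inline ModeInfo mode_info(const std::string& mode) {
      if (mode == "affine") return {"affine_", 0, 0};
      if (mode == "nvfp4") return {"fp_", 16, 4};
      if (mode == "mxfp8") return {"fp_", 32, 8};
      return {"fp_", 32, 4};  // mxfp4
    }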