mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 00:36:32 +08:00
speedup
This commit is contained in:
parent
6295e53216
commit
51449428dd
@ -1,4 +1,4 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <metal_simdgroup>
|
||||
#include <metal_stdlib>
|
||||
@ -24,6 +24,17 @@ inline constexpr short get_bytes_per_pack() {
|
||||
return wsize / 8;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static inline T dequantize_scale(uint8_t s) {
|
||||
using FOrI = union {
|
||||
bfloat16_t f;
|
||||
uint16_t i;
|
||||
};
|
||||
FOrI out;
|
||||
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
|
||||
return static_cast<T>(out.f);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int values_per_thread>
|
||||
inline void load_vector(const device T* x, thread U* x_thread) {
|
||||
for (int i = 0; i < values_per_thread; i += 4) {
|
||||
@ -48,7 +59,7 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
|
||||
}
|
||||
}
|
||||
|
||||
constant float MXFP4_LUT[16] = {
|
||||
constexpr constant static float MXFP4_LUT[16] = {
|
||||
+0.0f,
|
||||
+0.5f,
|
||||
+1.0f,
|
||||
@ -66,51 +77,74 @@ constant float MXFP4_LUT[16] = {
|
||||
-4.0f,
|
||||
-6.0f};
|
||||
|
||||
template <typename U, int values_per_thread>
|
||||
inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) {
|
||||
U accum = 0;
|
||||
template <typename T>
|
||||
void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
|
||||
if (simd_gid == 0 && simd_lid < 16) {
|
||||
lut[simd_lid] = static_cast<T>(MXFP4_LUT[simd_lid]);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
template <typename U, int values_per_thread>
|
||||
inline U qdot(
|
||||
const device uint8_t* w,
|
||||
const thread U* x_thread,
|
||||
U scale,
|
||||
const threadgroup U* lut) {
|
||||
U accum = 0;
|
||||
const device uint16_t* ws = (const device uint16_t*)w;
|
||||
for (int i = 0; i < (values_per_thread / 4); i++) {
|
||||
accum +=
|
||||
(x_thread[4 * i] * MXFP4_LUT[ws[i] & 0x000f] +
|
||||
x_thread[4 * i + 1] * MXFP4_LUT[(ws[i] & 0x00f0) >> 4] +
|
||||
x_thread[4 * i + 2] * MXFP4_LUT[(ws[i] & 0x0f00) >> 8] +
|
||||
x_thread[4 * i + 3] * MXFP4_LUT[(ws[i] & 0xf000) >> 12]);
|
||||
(x_thread[4 * i] * lut[ws[i] & 0x000f] +
|
||||
x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0x000f] +
|
||||
x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0x000f] +
|
||||
x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0x000f]);
|
||||
}
|
||||
return scale * accum;
|
||||
}
|
||||
|
||||
template <typename U, int values_per_thread, typename S>
|
||||
inline U
|
||||
qdot_safe(const device uint8_t* w, const thread U* x_thread, S scale, int N) {
|
||||
template <typename U, int values_per_thread>
|
||||
inline U qdot_safe(
|
||||
const device uint8_t* w,
|
||||
const thread U* x_thread,
|
||||
U scale,
|
||||
const threadgroup U* lut,
|
||||
int N) {
|
||||
U accum = 0;
|
||||
|
||||
const device uint16_t* ws = (const device uint16_t*)w;
|
||||
for (int i = 0; i < (N / 4); i++) {
|
||||
accum +=
|
||||
(x_thread[4 * i] * MXFP4_LUT[ws[i] & 0x000f] +
|
||||
x_thread[4 * i + 1] * MXFP4_LUT[(ws[i] & 0x00f0) >> 4] +
|
||||
x_thread[4 * i + 2] * MXFP4_LUT[(ws[i] & 0x0f00) >> 8] +
|
||||
x_thread[4 * i + 3] * MXFP4_LUT[(ws[i] & 0xf000) >> 12]);
|
||||
(x_thread[4 * i] * lut[ws[i] & 0x000f] +
|
||||
x_thread[4 * i + 1] * lut[(ws[i] & 0x00f0) >> 4] +
|
||||
x_thread[4 * i + 2] * lut[(ws[i] & 0x0f00) >> 8] +
|
||||
x_thread[4 * i + 3] * lut[(ws[i] & 0xf000) >> 12]);
|
||||
}
|
||||
return scale * accum;
|
||||
}
|
||||
|
||||
template <typename U, int values_per_thread>
|
||||
inline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) {
|
||||
inline void qouter(
|
||||
const thread uint8_t* w,
|
||||
U x,
|
||||
U scale,
|
||||
thread U* result,
|
||||
const threadgroup U* lut) {
|
||||
for (int i = 0; i < (values_per_thread / 2); i++) {
|
||||
result[2 * i] += x * scale * MXFP4_LUT[w[i] & 0x0f];
|
||||
result[2 * i + 1] += x * scale * MXFP4_LUT[(w[i] & 0xf0) >> 4];
|
||||
result[2 * i] += x * scale * lut[w[i] & 0x0f];
|
||||
result[2 * i + 1] += x * scale * lut[(w[i] & 0xf0) >> 4];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, int N>
|
||||
inline void
|
||||
dequantize(const device uint8_t* w, U scale, threadgroup U* w_local) {
|
||||
inline void dequantize(
|
||||
const device uint8_t* w,
|
||||
U scale,
|
||||
threadgroup U* w_local,
|
||||
const threadgroup U* lut) {
|
||||
for (int i = 0; i < (N / 2); i++) {
|
||||
w_local[2 * i] = scale * static_cast<U>(MXFP4_LUT[w[i] & 0x0f]);
|
||||
w_local[2 * i + 1] = scale * static_cast<U>(MXFP4_LUT[(w[i] & 0xf0) >> 4]);
|
||||
w_local[2 * i] = scale * lut[w[i] & 0x0f];
|
||||
w_local[2 * i + 1] = scale * lut[(w[i] & 0xf0) >> 4];
|
||||
}
|
||||
}
|
||||
|
||||
@ -150,12 +184,14 @@ struct QuantizedBlockLoader {
|
||||
threadgroup T* dst;
|
||||
const device uint8_t* src;
|
||||
const device S* scales;
|
||||
threadgroup T* lut;
|
||||
|
||||
QuantizedBlockLoader(
|
||||
const device uint8_t* src_,
|
||||
const device S* scales_,
|
||||
const int src_ld_,
|
||||
threadgroup T* dst_,
|
||||
threadgroup T* lut_,
|
||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(src_ld_),
|
||||
@ -170,17 +206,20 @@ struct QuantizedBlockLoader {
|
||||
dst(dst_ + bi * dst_ld + bj * pack_factor),
|
||||
src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
|
||||
bj * bytes_per_pack),
|
||||
scales(scales_ + bi * src_ld / group_size) {}
|
||||
scales(scales_ + bi * src_ld / group_size),
|
||||
lut(lut_) {
|
||||
load_mxfp4_lut(lut, simd_group_id, simd_lane_id);
|
||||
}
|
||||
|
||||
void load_unsafe() const {
|
||||
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
|
||||
return;
|
||||
}
|
||||
|
||||
T scale = metal::pow(T(2.0), static_cast<int>(*scales) - 127);
|
||||
T scale = dequantize_scale<T>(*scales);
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
dequantize<T, pack_factor>(
|
||||
src + i * bytes_per_pack, scale, dst + i * pack_factor);
|
||||
src + i * bytes_per_pack, scale, dst + i * pack_factor, lut);
|
||||
}
|
||||
}
|
||||
|
||||
@ -203,12 +242,13 @@ struct QuantizedBlockLoader {
|
||||
return;
|
||||
}
|
||||
|
||||
T scale = metal::pow(T(2.0), static_cast<int>(*scales) - 127);
|
||||
T scale = dequantize_scale<T>(*scales);
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
dequantize<T, pack_factor>(
|
||||
(device uint8_t*)(src + i * bytes_per_pack),
|
||||
scale,
|
||||
dst + i * pack_factor);
|
||||
dst + i * pack_factor,
|
||||
lut);
|
||||
}
|
||||
}
|
||||
|
||||
@ -240,7 +280,10 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
|
||||
const constant int& out_vec_size,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
||||
uint quad_lid [[thread_index_in_quadgroup]]) {
|
||||
uint quad_lid [[thread_index_in_quadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]],
|
||||
threadgroup float* lut) {
|
||||
constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
|
||||
constexpr int pack_factor = 8;
|
||||
constexpr int values_per_thread = D / QUAD_SIZE;
|
||||
@ -252,6 +295,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
|
||||
|
||||
thread U x_thread[values_per_thread];
|
||||
thread U result[results_per_quadgroup] = {0};
|
||||
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size / pack_factor;
|
||||
@ -269,9 +313,9 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
|
||||
auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
|
||||
const device S* sl = scales + row * in_vec_size_g * quads_per_simd;
|
||||
|
||||
U s = metal::pow(2.0f, static_cast<int>(sl[0]) - 127);
|
||||
U s = dequantize_scale<U>(sl[0]);
|
||||
if (row * quads_per_simd + out_row < out_vec_size) {
|
||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
|
||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
|
||||
}
|
||||
}
|
||||
|
||||
@ -293,7 +337,8 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
|
||||
const constant int& out_vec_size,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
uint simd_lid [[thread_index_in_simdgroup]],
|
||||
threadgroup float* lut) {
|
||||
constexpr int packs_per_thread = 2;
|
||||
constexpr int num_simdgroups = 2;
|
||||
constexpr int results_per_simdgroup = 4;
|
||||
@ -306,9 +351,9 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
|
||||
const device uint8_t* ws = (const device uint8_t*)w;
|
||||
|
||||
typedef float U;
|
||||
|
||||
thread U x_thread[values_per_thread];
|
||||
thread U result[results_per_simdgroup] = {0};
|
||||
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
||||
@ -328,8 +373,8 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
|
||||
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
||||
const device auto* sl = scales + row * in_vec_size_g;
|
||||
|
||||
U s = metal::pow(2.0f, static_cast<int>(sl[0]) - 127);
|
||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
|
||||
U s = dequantize_scale<U>(sl[0]);
|
||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
|
||||
}
|
||||
|
||||
ws += block_size * bytes_per_pack / pack_factor;
|
||||
@ -355,7 +400,8 @@ METAL_FUNC void mxfp4_qmv_impl(
|
||||
const constant int& out_vec_size,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
uint simd_lid [[thread_index_in_simdgroup]],
|
||||
threadgroup float* lut) {
|
||||
constexpr int num_simdgroups = 2;
|
||||
constexpr int results_per_simdgroup = 4;
|
||||
constexpr int packs_per_thread = 1;
|
||||
@ -372,6 +418,7 @@ METAL_FUNC void mxfp4_qmv_impl(
|
||||
|
||||
thread U x_thread[values_per_thread];
|
||||
thread U result[results_per_simdgroup] = {0};
|
||||
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
||||
@ -402,7 +449,7 @@ METAL_FUNC void mxfp4_qmv_impl(
|
||||
const device auto* sl = scales + row * in_vec_size_g;
|
||||
|
||||
S s = sl[0];
|
||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
|
||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
|
||||
}
|
||||
|
||||
ws += block_size * bytes_per_pack / pack_factor;
|
||||
@ -420,8 +467,8 @@ METAL_FUNC void mxfp4_qmv_impl(
|
||||
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
||||
const device auto* sl = scales + row * in_vec_size_g;
|
||||
|
||||
U s = metal::pow(2.0f, static_cast<int>(sl[0]) - 127);
|
||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
|
||||
U s = dequantize_scale<U>(sl[0]);
|
||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
|
||||
}
|
||||
}
|
||||
|
||||
@ -449,8 +496,8 @@ METAL_FUNC void mxfp4_qmv_impl(
|
||||
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
||||
const device auto* sl = scales + row * in_vec_size_g;
|
||||
|
||||
U s = metal::pow(2.0f, static_cast<int>(sl[0]) - 127);
|
||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
|
||||
U s = dequantize_scale<U>(sl[0]);
|
||||
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
|
||||
}
|
||||
|
||||
ws += block_size * bytes_per_pack / pack_factor;
|
||||
@ -468,9 +515,9 @@ METAL_FUNC void mxfp4_qmv_impl(
|
||||
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
||||
const device auto* sl = scales + row * in_vec_size_g;
|
||||
|
||||
U s = metal::pow(2.0f, static_cast<int>(sl[0]) - 127);
|
||||
U s = dequantize_scale<U>(sl[0]);
|
||||
result[row] +=
|
||||
qdot_safe<U, values_per_thread>(wl, x_thread, s, remaining);
|
||||
qdot_safe<U, values_per_thread>(wl, x_thread, s, lut, remaining);
|
||||
}
|
||||
}
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
@ -492,7 +539,8 @@ METAL_FUNC void mxfp4_qvm_impl(
|
||||
const int out_vec_size,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
uint simd_lid [[thread_index_in_simdgroup]],
|
||||
threadgroup float* lut) {
|
||||
constexpr int num_simdgroups = 2;
|
||||
constexpr int pack_factor = get_pack_factor<32>();
|
||||
constexpr int bytes_per_pack = get_bytes_per_pack();
|
||||
@ -513,6 +561,8 @@ METAL_FUNC void mxfp4_qvm_impl(
|
||||
thread U scale = 0;
|
||||
thread U x_local = 0;
|
||||
|
||||
load_mxfp4_lut(lut, simd_gid, simd_lid);
|
||||
|
||||
// Adjust positions
|
||||
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
|
||||
const int out_vec_size_g = out_vec_size / group_size;
|
||||
@ -531,10 +581,10 @@ METAL_FUNC void mxfp4_qvm_impl(
|
||||
if (remaining == 0) {
|
||||
for (int i = 0; i < in_vec_size; i += block_size) {
|
||||
x_local = *x;
|
||||
scale = metal::pow(2.0f, static_cast<int>(*scales) - 127);
|
||||
scale = dequantize_scale<U>(*scales);
|
||||
w_local = *((device vec_w*)ws);
|
||||
qouter<U, tn * pack_factor>(
|
||||
(thread uint8_t*)&w_local, x_local, scale, result);
|
||||
(thread uint8_t*)&w_local, x_local, scale, result, lut);
|
||||
|
||||
x += block_size;
|
||||
scales += block_size * out_vec_size_g;
|
||||
@ -543,11 +593,11 @@ METAL_FUNC void mxfp4_qvm_impl(
|
||||
} else {
|
||||
for (int i = block_size; i < in_vec_size; i += block_size) {
|
||||
x_local = *x;
|
||||
scale = metal::pow(2.0f, static_cast<int>(*scales) - 127);
|
||||
scale = dequantize_scale<U>(*scales);
|
||||
w_local = *((device vec_w*)ws);
|
||||
|
||||
qouter<U, tn * pack_factor>(
|
||||
(thread uint8_t*)&w_local, x_local, scale, result);
|
||||
(thread uint8_t*)&w_local, x_local, scale, result, lut);
|
||||
|
||||
x += block_size;
|
||||
scales += block_size * out_vec_size_g;
|
||||
@ -555,14 +605,14 @@ METAL_FUNC void mxfp4_qvm_impl(
|
||||
}
|
||||
if (static_cast<int>(simd_lid) < remaining) {
|
||||
x_local = *x;
|
||||
scale = metal::pow(2.0f, static_cast<int>(*scales) - 127);
|
||||
scale = dequantize_scale<U>(*scales);
|
||||
w_local = *((device vec_w*)ws);
|
||||
} else {
|
||||
x_local = 0;
|
||||
scale = 0;
|
||||
}
|
||||
qouter<U, tn * pack_factor>(
|
||||
(thread uint8_t*)&w_local, x_local, scale, result);
|
||||
(thread uint8_t*)&w_local, x_local, scale, result, lut);
|
||||
}
|
||||
|
||||
// Accumulate in the simdgroup
|
||||
@ -601,7 +651,8 @@ METAL_FUNC void mxfp4_qmm_t_impl(
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
uint simd_lid [[thread_index_in_simdgroup]],
|
||||
threadgroup T* lut) {
|
||||
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
||||
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
||||
|
||||
@ -646,7 +697,7 @@ METAL_FUNC void mxfp4_qmm_t_impl(
|
||||
const short num_els = min(BM, M - y_row);
|
||||
const short num_outs = min(BN, N - y_col);
|
||||
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
|
||||
loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid);
|
||||
loader_w_t loader_w(wl, scales, K, Ws, lut, simd_gid, simd_lid);
|
||||
mma_t mma_op(simd_gid, simd_lid);
|
||||
|
||||
if (num_els < BM) {
|
||||
@ -725,7 +776,8 @@ METAL_FUNC void mxfp4_qmm_n_impl(
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
uint simd_lid [[thread_index_in_simdgroup]],
|
||||
threadgroup T* lut) {
|
||||
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
||||
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
||||
|
||||
@ -767,7 +819,7 @@ METAL_FUNC void mxfp4_qmm_n_impl(
|
||||
// Make the x loader and mma operation
|
||||
const short num_els = min(BM, M - y_row);
|
||||
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
|
||||
loader_w_t loader_w(wl, scales, N, Ws, simd_gid, simd_lid);
|
||||
loader_w_t loader_w(wl, scales, N, Ws, lut, simd_gid, simd_lid);
|
||||
mma_t mma_op(simd_gid, simd_lid);
|
||||
|
||||
if (num_els < BM) {
|
||||
@ -941,7 +993,9 @@ template <typename T, int group_size, int D, bool batched, typename S>
|
||||
const constant int64_t* s_strides,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
||||
uint quad_lid [[thread_index_in_quadgroup]]) {
|
||||
uint quad_lid [[thread_index_in_quadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
if (batched) {
|
||||
int M = x_shape[x_batch_ndims];
|
||||
adjust_matrix_offsets(
|
||||
@ -959,8 +1013,20 @@ template <typename T, int group_size, int D, bool batched, typename S>
|
||||
s_strides,
|
||||
tid);
|
||||
}
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qmv_quad_impl<T, group_size, D>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid);
|
||||
w,
|
||||
scales,
|
||||
x,
|
||||
y,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tid,
|
||||
quad_gid,
|
||||
quad_lid,
|
||||
simd_gid,
|
||||
simd_lid,
|
||||
lut);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, bool batched, typename S>
|
||||
@ -998,8 +1064,9 @@ template <typename T, int group_size, bool batched, typename S>
|
||||
s_strides,
|
||||
tid);
|
||||
}
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qmv_fast_impl<T, group_size>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, bool batched, typename S>
|
||||
@ -1037,8 +1104,9 @@ template <typename T, const int group_size, bool batched, typename S>
|
||||
s_strides,
|
||||
tid);
|
||||
}
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qmv_impl<T, group_size>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, bool batched, typename S>
|
||||
@ -1076,8 +1144,9 @@ template <typename T, const int group_size, bool batched, typename S>
|
||||
s_strides,
|
||||
tid);
|
||||
}
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qvm_impl<T, group_size>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, int split_k = 32, typename S>
|
||||
@ -1119,8 +1188,18 @@ template <typename T, const int group_size, int split_k = 32, typename S>
|
||||
int in_vec_size_adj =
|
||||
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
|
||||
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qvm_impl<T, group_size>(
|
||||
w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid);
|
||||
w,
|
||||
scales,
|
||||
x,
|
||||
y,
|
||||
in_vec_size_adj,
|
||||
out_vec_size,
|
||||
tid,
|
||||
simd_gid,
|
||||
simd_lid,
|
||||
lut);
|
||||
}
|
||||
|
||||
template <
|
||||
@ -1157,6 +1236,7 @@ template <
|
||||
|
||||
threadgroup T Xs[BM * BK_padded];
|
||||
threadgroup T Ws[BN * BK_padded];
|
||||
threadgroup T lut[16];
|
||||
|
||||
if (batched) {
|
||||
adjust_matrix_offsets(
|
||||
@ -1175,7 +1255,7 @@ template <
|
||||
tid);
|
||||
}
|
||||
mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <
|
||||
@ -1212,6 +1292,7 @@ template <
|
||||
|
||||
threadgroup T Xs[BM * BK_padded];
|
||||
threadgroup T Ws[BK * BN_padded];
|
||||
threadgroup T lut[16];
|
||||
|
||||
if (batched) {
|
||||
adjust_matrix_offsets(
|
||||
@ -1231,7 +1312,7 @@ template <
|
||||
}
|
||||
|
||||
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, typename S>
|
||||
@ -1279,8 +1360,9 @@ template <typename T, int group_size, typename S>
|
||||
w_strides,
|
||||
s_strides,
|
||||
tid);
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qmv_fast_impl<T, group_size>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, typename S>
|
||||
@ -1328,8 +1410,9 @@ template <typename T, int group_size, typename S>
|
||||
w_strides,
|
||||
s_strides,
|
||||
tid);
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qmv_impl<T, group_size>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, int group_size, typename S>
|
||||
@ -1377,8 +1460,9 @@ template <typename T, int group_size, typename S>
|
||||
w_strides,
|
||||
s_strides,
|
||||
tid);
|
||||
threadgroup float lut[16];
|
||||
mxfp4_qvm_impl<T, group_size>(
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
|
||||
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <
|
||||
@ -1420,6 +1504,7 @@ template <
|
||||
|
||||
threadgroup T Xs[BM * BK_padded];
|
||||
threadgroup T Ws[BN * BK_padded];
|
||||
threadgroup T lut[16];
|
||||
|
||||
adjust_matrix_offsets(
|
||||
x,
|
||||
@ -1442,7 +1527,7 @@ template <
|
||||
s_strides,
|
||||
tid);
|
||||
mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <
|
||||
@ -1484,6 +1569,7 @@ template <
|
||||
|
||||
threadgroup T Xs[BM * BK_padded];
|
||||
threadgroup T Ws[BK * BN_padded];
|
||||
threadgroup T lut[16];
|
||||
|
||||
adjust_matrix_offsets(
|
||||
x,
|
||||
@ -1506,7 +1592,7 @@ template <
|
||||
s_strides,
|
||||
tid);
|
||||
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
|
||||
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
|
||||
}
|
||||
|
||||
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
||||
@ -1621,6 +1707,7 @@ template <
|
||||
constexpr int bytes_per_pack = get_bytes_per_pack();
|
||||
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
||||
constexpr int BN_padded = (BN + 16 / sizeof(T));
|
||||
threadgroup T lut[16];
|
||||
|
||||
using mma_t = mlx::steel::BlockMMA<
|
||||
T,
|
||||
@ -1709,6 +1796,7 @@ template <
|
||||
scales + index * stride_s,
|
||||
transpose ? K : N,
|
||||
Ws,
|
||||
lut,
|
||||
simd_group_id,
|
||||
simd_lane_id);
|
||||
|
||||
|
@ -734,6 +734,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
for L, K, D, E, I, transpose, mode in parameters:
|
||||
if mode == "mxfp4":
|
||||
group_size = 32
|
||||
else:
|
||||
group_size = 64
|
||||
K, D = (K, D) if transpose else (D, K)
|
||||
ishape = (L, I)
|
||||
xshape = (L, 1, 1, K)
|
||||
|
Loading…
Reference in New Issue
Block a user