Awni Hannun 2025-08-20 14:05:35 -07:00
parent 6295e53216
commit 51449428dd
2 changed files with 157 additions and 67 deletions

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2025 Apple Inc.
#include <metal_simdgroup>
#include <metal_stdlib>
@@ -24,6 +24,17 @@ inline constexpr short get_bytes_per_pack() {
return wsize / 8;
}
template <typename T>
static inline T dequantize_scale(uint8_t s) {
using FOrI = union {
bfloat16_t f;
uint16_t i;
};
FOrI out;
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
return static_cast<T>(out.f);
}
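dequantize_scale decodes an E8M0 scale byte with a bit trick instead of metal::pow: bfloat16 has 7 mantissa bits, so shifting the byte left by 7 places it directly in the exponent field, and the resulting value is 2^(s - 127). A plain shift of s == 0 would produce +0.0, so that case substitutes the subnormal pattern 0x40, which is exactly 2^-127. A host-side sketch (not part of the commit; helper name is illustrative) checking the identity for all 256 bytes:

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// bfloat16 is the high half of a float32, so widen by 16 zero bits to decode.
static float bf16_to_float(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof f);
  return f;
}

int main() {
  for (int s = 0; s < 256; ++s) {
    uint16_t bits = (s == 0) ? 0x40 : static_cast<uint16_t>(s << 7);
    assert(bf16_to_float(bits) == std::ldexp(1.0f, s - 127)); // 2^(s - 127)
  }
}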
template <typename T, typename U, int values_per_thread>
inline void load_vector(const device T* x, thread U* x_thread) {
for (int i = 0; i < values_per_thread; i += 4) {
@@ -48,7 +59,7 @@ inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
}
}
constant float MXFP4_LUT[16] = {
constexpr constant static float MXFP4_LUT[16] = {
+0.0f,
+0.5f,
+1.0f,
@@ -66,51 +77,74 @@ constant float MXFP4_LUT[16] = {
-4.0f,
-6.0f};
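The 16 entries are the FP4 (E2M1) code points: 1 sign bit, 2 exponent bits with bias 1, and 1 mantissa bit, giving ±{0, 0.5, 1, 1.5, 2, 3, 4, 6}. A sketch (an assumption, following the OCP MXFP4 layout) deriving an entry from its nibble instead of the table:

#include <cmath>
#include <cstdint>

// Decode one E2M1 nibble: sign in bit 3, exponent in bits 2..1, mantissa in bit 0.
float decode_e2m1(uint8_t nib) {
  float sign = (nib & 0x8) ? -1.0f : 1.0f;
  int e = (nib >> 1) & 0x3;
  int m = nib & 0x1;
  if (e == 0)
    return sign * 0.5f * m; // subnormal: 0.0 or 0.5
  return sign * std::ldexp(1.0f + 0.5f * m, e - 1); // normal: 1.0 ... 6.0
}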
template <typename U, int values_per_thread>
inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) {
U accum = 0;
template <typename T>
void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
if (simd_gid == 0 && simd_lid < 16) {
lut[simd_lid] = static_cast<T>(MXFP4_LUT[simd_lid]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
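Staging the table in threadgroup memory keeps the hot lookups in low-latency shared storage. Note the barrier is threadgroup-wide: every thread must reach the call, even though only simdgroup 0's first 16 lanes write. Call shape (a sketch; variables as in the kernels below):

threadgroup float lut[16];
load_mxfp4_lut(lut, simd_gid, simd_lid); // 16 lanes write, everyone waits
float v = lut[code & 0xF]; // safe on any lane after the call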
template <typename U, int values_per_thread>
inline U qdot(
const device uint8_t* w,
const thread U* x_thread,
U scale,
const threadgroup U* lut) {
U accum = 0;
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (values_per_thread / 4); i++) {
accum +=
(x_thread[4 * i] * MXFP4_LUT[ws[i] & 0x000f] +
x_thread[4 * i + 1] * MXFP4_LUT[(ws[i] & 0x00f0) >> 4] +
x_thread[4 * i + 2] * MXFP4_LUT[(ws[i] & 0x0f00) >> 8] +
x_thread[4 * i + 3] * MXFP4_LUT[(ws[i] & 0xf000) >> 12]);
(x_thread[4 * i] * lut[ws[i] & 0x000f] +
x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0x000f] +
x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0x000f] +
x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0x000f]);
}
return scale * accum;
}
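Each uint16_t read from w packs four 4-bit codes, least-significant nibble first; the rewritten indexing shifts first and masks with 0x000f so every lookup goes through the threadgroup table. A scalar host reference for one packed word (hypothetical helper, not from the commit):

#include <cstdint>

// Dot four activations with four E2M1 codes packed low-nibble-first in w.
float qdot_u16(uint16_t w, const float x[4], const float lut[16]) {
  float acc = 0.0f;
  for (int j = 0; j < 4; ++j)
    acc += x[j] * lut[(w >> (4 * j)) & 0xF];
  return acc;
}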
template <typename U, int values_per_thread, typename S>
inline U
qdot_safe(const device uint8_t* w, const thread U* x_thread, S scale, int N) {
template <typename U, int values_per_thread>
inline U qdot_safe(
const device uint8_t* w,
const thread U* x_thread,
U scale,
const threadgroup U* lut,
int N) {
U accum = 0;
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (N / 4); i++) {
accum +=
(x_thread[4 * i] * MXFP4_LUT[ws[i] & 0x000f] +
x_thread[4 * i + 1] * MXFP4_LUT[(ws[i] & 0x00f0) >> 4] +
x_thread[4 * i + 2] * MXFP4_LUT[(ws[i] & 0x0f00) >> 8] +
x_thread[4 * i + 3] * MXFP4_LUT[(ws[i] & 0xf000) >> 12]);
(x_thread[4 * i] * lut[ws[i] & 0x000f] +
x_thread[4 * i + 1] * lut[(ws[i] & 0x00f0) >> 4] +
x_thread[4 * i + 2] * lut[(ws[i] & 0x0f00) >> 8] +
x_thread[4 * i + 3] * lut[(ws[i] & 0xf000) >> 12]);
}
return scale * accum;
}
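qdot_safe is the ragged-tail twin of qdot: the caller zero-pads x_thread past the N valid activations (load_vector_safe does this) and the loop bound shrinks to N / 4, so packed weights past the row end are never dereferenced. A host-side check of the underlying identity (stand-in names and data, not from the commit):

#include <cassert>
#include <cstdint>

static const float kLut[16] = {
    0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// Scalar stand-in for qdot/qdot_safe over one row of packed codes.
static float dot_packed(const uint16_t* ws, const float* x, int vals) {
  float acc = 0.0f;
  for (int i = 0; i < vals / 4; ++i)
    for (int j = 0; j < 4; ++j)
      acc += x[4 * i + j] * kLut[(ws[i] >> (4 * j)) & 0xF];
  return acc;
}

int main() {
  uint16_t ws[2] = {0x3210, 0x7654}; // eight packed codes
  float x[8] = {1, 2, 3, 4, 0, 0, 0, 0}; // zero-padded past N = 4
  // With x zeroed past N, the full-width dot equals the N-bounded one.
  assert(dot_packed(ws, x, 8) == dot_packed(ws, x, 4));
}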
template <typename U, int values_per_thread>
inline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) {
inline void qouter(
const thread uint8_t* w,
U x,
U scale,
thread U* result,
const threadgroup U* lut) {
for (int i = 0; i < (values_per_thread / 2); i++) {
result[2 * i] += x * scale * MXFP4_LUT[w[i] & 0x0f];
result[2 * i + 1] += x * scale * MXFP4_LUT[(w[i] & 0xf0) >> 4];
result[2 * i] += x * scale * lut[w[i] & 0x0f];
result[2 * i + 1] += x * scale * lut[(w[i] & 0xf0) >> 4];
}
}
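qouter is the qvm-side counterpart: one broadcast activation x times a strip of decoded weights, accumulated into a row of partial outputs, two codes per byte with the low nibble first. A scalar host reference (hypothetical helper):

#include <cstdint>

// result[j] += x * scale * w_row[j], decoding two E2M1 codes per byte.
void qouter_ref(const uint8_t* w, float x, float scale, float* result,
                const float* lut, int values) {
  for (int i = 0; i < values / 2; ++i) {
    result[2 * i] += x * scale * lut[w[i] & 0x0F];
    result[2 * i + 1] += x * scale * lut[w[i] >> 4];
  }
}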
template <typename U, int N>
inline void
dequantize(const device uint8_t* w, U scale, threadgroup U* w_local) {
inline void dequantize(
const device uint8_t* w,
U scale,
threadgroup U* w_local,
const threadgroup U* lut) {
for (int i = 0; i < (N / 2); i++) {
w_local[2 * i] = scale * static_cast<U>(MXFP4_LUT[w[i] & 0x0f]);
w_local[2 * i + 1] = scale * static_cast<U>(MXFP4_LUT[(w[i] & 0xf0) >> 4]);
w_local[2 * i] = scale * lut[w[i] & 0x0f];
w_local[2 * i + 1] = scale * lut[(w[i] & 0xf0) >> 4];
}
}
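dequantize expands packed weights into threadgroup memory for the block-matmul paths. Together with the shared E8M0 byte, one 32-value MXFP4 group occupies 17 bytes. A host-side reference for a whole group (names are illustrative; the equivalence of ldexp with dequantize_scale is checked above):

#include <cmath>
#include <cstdint>

// Expand one MXFP4 group: 16 packed bytes + 1 E8M0 scale byte -> 32 floats.
void dequant_group(const uint8_t w[16], uint8_t s, const float lut[16],
                   float out[32]) {
  float scale = std::ldexp(1.0f, static_cast<int>(s) - 127);
  for (int i = 0; i < 16; ++i) {
    out[2 * i] = scale * lut[w[i] & 0x0F];
    out[2 * i + 1] = scale * lut[w[i] >> 4];
  }
}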
@@ -150,12 +184,14 @@ struct QuantizedBlockLoader {
threadgroup T* dst;
const device uint8_t* src;
const device S* scales;
threadgroup T* lut;
QuantizedBlockLoader(
const device uint8_t* src_,
const device S* scales_,
const int src_ld_,
threadgroup T* dst_,
threadgroup T* lut_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
@@ -170,17 +206,20 @@ struct QuantizedBlockLoader {
dst(dst_ + bi * dst_ld + bj * pack_factor),
src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
bj * bytes_per_pack),
scales(scales_ + bi * src_ld / group_size) {}
scales(scales_ + bi * src_ld / group_size),
lut(lut_) {
load_mxfp4_lut(lut, simd_group_id, simd_lane_id);
}
void load_unsafe() const {
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
return;
}
T scale = metal::pow(T(2.0), static_cast<int>(*scales) - 127);
T scale = dequantize_scale<T>(*scales);
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor>(
src + i * bytes_per_pack, scale, dst + i * pack_factor);
src + i * bytes_per_pack, scale, dst + i * pack_factor, lut);
}
}
@@ -203,12 +242,13 @@ struct QuantizedBlockLoader {
return;
}
T scale = metal::pow(T(2.0), static_cast<int>(*scales) - 127);
T scale = dequantize_scale<T>(*scales);
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor>(
(device uint8_t*)(src + i * bytes_per_pack),
scale,
dst + i * pack_factor);
dst + i * pack_factor,
lut);
}
}
@@ -240,7 +280,10 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
uint quad_lid [[thread_index_in_quadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
constexpr int pack_factor = 8;
constexpr int values_per_thread = D / QUAD_SIZE;
@@ -252,6 +295,7 @@ METAL_FUNC void mxfp4_qmv_quad_impl(
thread U x_thread[values_per_thread];
thread U result[results_per_quadgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
@@ -269,9 +313,9 @@
auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
const device S* sl = scales + row * in_vec_size_g * quads_per_simd;
U s = metal::pow(2.0f, static_cast<int>(sl[0]) - 127);
U s = dequantize_scale<U>(sl[0]);
if (row * quads_per_simd + out_row < out_vec_size) {
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
}
}
@@ -293,7 +337,8 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
constexpr int packs_per_thread = 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
@@ -306,9 +351,9 @@ METAL_FUNC void mxfp4_qmv_fast_impl(
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -328,8 +373,8 @@
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device auto* sl = scales + row * in_vec_size_g;
U s = metal::pow(2.0f, static_cast<int>(sl[0]) - 127);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
U s = dequantize_scale<U>(sl[0]);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
}
ws += block_size * bytes_per_pack / pack_factor;
@@ -355,7 +400,8 @@ METAL_FUNC void mxfp4_qmv_impl(
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int packs_per_thread = 1;
@@ -372,6 +418,7 @@ METAL_FUNC void mxfp4_qmv_impl(
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
@@ -402,7 +449,7 @@
const device auto* sl = scales + row * in_vec_size_g;
S s = sl[0];
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
}
ws += block_size * bytes_per_pack / pack_factor;
@@ -420,8 +467,8 @@
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device auto* sl = scales + row * in_vec_size_g;
U s = metal::pow(2.0f, static_cast<int>(sl[0]) - 127);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
U s = dequantize_scale<U>(sl[0]);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
}
}
@@ -449,8 +496,8 @@ METAL_FUNC void mxfp4_qmv_impl(
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device auto* sl = scales + row * in_vec_size_g;
U s = metal::pow(2.0f, static_cast<int>(sl[0]) - 127);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s);
U s = dequantize_scale<U>(sl[0]);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
}
ws += block_size * bytes_per_pack / pack_factor;
@@ -468,9 +515,9 @@
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device auto* sl = scales + row * in_vec_size_g;
U s = metal::pow(2.0f, static_cast<int>(sl[0]) - 127);
U s = dequantize_scale<U>(sl[0]);
result[row] +=
qdot_safe<U, values_per_thread>(wl, x_thread, s, remaining);
qdot_safe<U, values_per_thread>(wl, x_thread, s, lut, remaining);
}
}
for (int row = 0; row < results_per_simdgroup; row++) {
@@ -492,7 +539,8 @@ METAL_FUNC void mxfp4_qvm_impl(
const int out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
constexpr int num_simdgroups = 2;
constexpr int pack_factor = get_pack_factor<32>();
constexpr int bytes_per_pack = get_bytes_per_pack();
@@ -513,6 +561,8 @@ METAL_FUNC void mxfp4_qvm_impl(
thread U scale = 0;
thread U x_local = 0;
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
const int out_vec_size_g = out_vec_size / group_size;
@@ -531,10 +581,10 @@ METAL_FUNC void mxfp4_qvm_impl(
if (remaining == 0) {
for (int i = 0; i < in_vec_size; i += block_size) {
x_local = *x;
scale = metal::pow(2.0f, static_cast<int>(*scales) - 127);
scale = dequantize_scale<U>(*scales);
w_local = *((device vec_w*)ws);
qouter<U, tn * pack_factor>(
(thread uint8_t*)&w_local, x_local, scale, result);
(thread uint8_t*)&w_local, x_local, scale, result, lut);
x += block_size;
scales += block_size * out_vec_size_g;
@@ -543,11 +593,11 @@
} else {
for (int i = block_size; i < in_vec_size; i += block_size) {
x_local = *x;
scale = metal::pow(2.0f, static_cast<int>(*scales) - 127);
scale = dequantize_scale<U>(*scales);
w_local = *((device vec_w*)ws);
qouter<U, tn * pack_factor>(
(thread uint8_t*)&w_local, x_local, scale, result);
(thread uint8_t*)&w_local, x_local, scale, result, lut);
x += block_size;
scales += block_size * out_vec_size_g;
@@ -555,14 +605,14 @@
}
if (static_cast<int>(simd_lid) < remaining) {
x_local = *x;
scale = metal::pow(2.0f, static_cast<int>(*scales) - 127);
scale = dequantize_scale<U>(*scales);
w_local = *((device vec_w*)ws);
} else {
x_local = 0;
scale = 0;
}
qouter<U, tn * pack_factor>(
(thread uint8_t*)&w_local, x_local, scale, result);
(thread uint8_t*)&w_local, x_local, scale, result, lut);
}
// Accumulate in the simdgroup
@@ -601,7 +651,8 @@ METAL_FUNC void mxfp4_qmm_t_impl(
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup T* lut) {
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
@@ -646,7 +697,7 @@ METAL_FUNC void mxfp4_qmm_t_impl(
const short num_els = min(BM, M - y_row);
const short num_outs = min(BN, N - y_col);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid);
loader_w_t loader_w(wl, scales, K, Ws, lut, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
if (num_els < BM) {
@@ -725,7 +776,8 @@ METAL_FUNC void mxfp4_qmm_n_impl(
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup T* lut) {
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
@@ -767,7 +819,7 @@ METAL_FUNC void mxfp4_qmm_n_impl(
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
loader_w_t loader_w(wl, scales, N, Ws, simd_gid, simd_lid);
loader_w_t loader_w(wl, scales, N, Ws, lut, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
if (num_els < BM) {
@@ -941,7 +993,9 @@ template <typename T, int group_size, int D, bool batched, typename S>
const constant int64_t* s_strides,
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
uint quad_lid [[thread_index_in_quadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets(
@@ -959,8 +1013,20 @@ template <typename T, int group_size, int D, bool batched, typename S>
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qmv_quad_impl<T, group_size, D>(
w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid);
w,
scales,
x,
y,
in_vec_size,
out_vec_size,
tid,
quad_gid,
quad_lid,
simd_gid,
simd_lid,
lut);
}
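From here on the changes are plumbing: every kernel wrapper declares the table itself (threadgroup float lut[16], or threadgroup T lut[16] in the qmm paths) and threads it through to the impl, since these threadgroup allocations are made at kernel scope rather than inside the shared helpers. A stripped-down sketch of the pattern (hypothetical kernel, details elided):

// The kernel owns the threadgroup table; the impl functions only borrow it.
kernel void example_mxfp4_qmv(
    const device uint8_t* w [[buffer(0)]],
    device float* y [[buffer(1)]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  threadgroup float lut[16]; // allocated once per threadgroup
  load_mxfp4_lut(lut, simd_gid, simd_lid); // fill + barrier
  // ... main loop passes `lut` into qdot / qouter / dequantize ...
  (void)w;
  (void)y;
}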
template <typename T, int group_size, bool batched, typename S>
@@ -998,8 +1064,9 @@ template <typename T, int group_size, bool batched, typename S>
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qmv_fast_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <typename T, const int group_size, bool batched, typename S>
@@ -1037,8 +1104,9 @@ template <typename T, const int group_size, bool batched, typename S>
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qmv_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <typename T, const int group_size, bool batched, typename S>
@@ -1076,8 +1144,9 @@ template <typename T, const int group_size, bool batched, typename S>
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <typename T, const int group_size, int split_k = 32, typename S>
@@ -1119,8 +1188,18 @@ template <typename T, const int group_size, int split_k = 32, typename S>
int in_vec_size_adj =
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>(
w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid);
w,
scales,
x,
y,
in_vec_size_adj,
out_vec_size,
tid,
simd_gid,
simd_lid,
lut);
}
template <
@@ -1157,6 +1236,7 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
threadgroup T lut[16];
if (batched) {
adjust_matrix_offsets(
@@ -1175,7 +1255,7 @@ template <
tid);
}
mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <
@@ -1212,6 +1292,7 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded];
threadgroup T lut[16];
if (batched) {
adjust_matrix_offsets(
@@ -1231,7 +1312,7 @@ template <
}
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <typename T, int group_size, typename S>
@@ -1279,8 +1360,9 @@ template <typename T, int group_size, typename S>
w_strides,
s_strides,
tid);
threadgroup float lut[16];
mxfp4_qmv_fast_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <typename T, int group_size, typename S>
@@ -1328,8 +1410,9 @@ template <typename T, int group_size, typename S>
w_strides,
s_strides,
tid);
threadgroup float lut[16];
mxfp4_qmv_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <typename T, int group_size, typename S>
@@ -1377,8 +1460,9 @@ template <typename T, int group_size, typename S>
w_strides,
s_strides,
tid);
threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <
@@ -1420,6 +1504,7 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
threadgroup T lut[16];
adjust_matrix_offsets(
x,
@@ -1442,7 +1527,7 @@
s_strides,
tid);
mxfp4_qmm_t_impl<T, group_size, aligned_N, S, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <
@@ -1484,6 +1569,7 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded];
threadgroup T lut[16];
adjust_matrix_offsets(
x,
@@ -1506,7 +1592,7 @@
s_strides,
tid);
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
@@ -1621,6 +1707,7 @@
constexpr int bytes_per_pack = get_bytes_per_pack();
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
threadgroup T lut[16];
using mma_t = mlx::steel::BlockMMA<
T,
@@ -1709,6 +1796,7 @@
scales + index * stride_s,
transpose ? K : N,
Ws,
lut,
simd_group_id,
simd_lane_id);

View File

@@ -734,6 +734,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
for L, K, D, E, I, transpose, mode in parameters:
if mode == "mxfp4":
group_size = 32
else:
group_size = 64
K, D = (K, D) if transpose else (D, K)
ishape = (L, I)
xshape = (L, 1, 1, K)