// mlx/mlx/backend/metal/kernels/fp4_quantized.h
// Copyright © 2025 Apple Inc.
#include <metal_simdgroup>
#include <metal_stdlib>
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;
template <int wsize = 8>
inline constexpr short get_pack_factor() {
return wsize / 4;
}
template <int wsize = 8>
inline constexpr short get_bytes_per_pack() {
return wsize / 8;
}
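// Decode an MXFP4 shared scale (E8M0): the byte is an 8-bit biased exponent.
// Shifting it left by 7 drops it into the exponent field of a bfloat16, so the
// result is 2^(s - 127); for example s = 127 gives the bits 0x3F80, i.e. 1.0.
// s == 0 is special-cased to the subnormal bit pattern 0x0040, which encodes
// exactly 2^-127.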
template <typename T>
static inline T dequantize_scale(uint8_t s) {
using FOrI = union {
bfloat16_t f;
uint16_t i;
};
FOrI out;
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
return static_cast<T>(out.f);
}
template <typename T, typename U, int values_per_thread>
inline void load_vector(const device T* x, thread U* x_thread) {
for (int i = 0; i < values_per_thread; i += 4) {
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1];
x_thread[i + 2] = x[i + 2];
x_thread[i + 3] = x[i + 3];
}
}
template <typename T, typename U, int values_per_thread>
inline void load_vector_safe(const device T* x, thread U* x_thread, int N) {
for (int i = 0; i < N; i += 4) {
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1];
x_thread[i + 2] = x[i + 2];
x_thread[i + 3] = x[i + 3];
}
for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;
}
}
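// The 16 MXFP4 (E2M1) element values: bit 3 is the sign, bits 2:1 the
// exponent, bit 0 the mantissa, giving +/-{0, 0.5, 1, 1.5, 2, 3, 4, 6}.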
constexpr constant static float MXFP4_LUT[16] = {
+0.0f,
+0.5f,
+1.0f,
+1.5f,
+2.0f,
+3.0f,
+4.0f,
+6.0f,
-0.0f,
-0.5f,
-1.0f,
-1.5f,
-2.0f,
-3.0f,
-4.0f,
-6.0f};
template <typename T>
void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) {
if (simd_gid == 0 && simd_lid < 16) {
lut[simd_lid] = static_cast<T>(MXFP4_LUT[simd_lid]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
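// Dot product of values_per_thread activations with packed 4-bit weights: each
// 16-bit read carries four nibble codes that index the threadgroup LUT, and
// the shared scale is applied once to the accumulated sum.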
template <typename U, int values_per_thread>
inline U qdot(
const device uint8_t* w,
const thread U* x_thread,
U scale,
const threadgroup U* lut) {
U accum = 0;
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (values_per_thread / 4); i++) {
accum +=
(x_thread[4 * i] * lut[ws[i] & 0xf] +
x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
}
return scale * accum;
}
template <typename U, int values_per_thread>
inline U qdot_safe(
const device uint8_t* w,
const thread U* x_thread,
U scale,
const threadgroup U* lut,
int N) {
U accum = 0;
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (N / 4); i++) {
accum +=
(x_thread[4 * i] * lut[ws[i] & 0xf] +
x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
}
return scale * accum;
}
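// For reference, a host-side sketch (not part of this header) of the decode
// that qdot and dequantize perform, written in plain C++ and assuming the
// MXFP4 group size of 32; the names below are illustrative only:
//
//   #include <cmath>
//   #include <cstdint>
//
//   static const float kLut[16] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f,
//                                  -0.f, -0.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f};
//
//   // w packs two 4-bit codes per byte (low nibble first); scales holds one
//   // E8M0 byte per 32 values.
//   void mxfp4_decode(const uint8_t* w, const uint8_t* scales, float* out, int n) {
//     for (int i = 0; i < n; i += 2) {
//       float s = std::ldexp(1.0f, int(scales[i / 32]) - 127); // 2^(e - 127)
//       out[i] = s * kLut[w[i / 2] & 0xf];
//       out[i + 1] = s * kLut[(w[i / 2] >> 4) & 0xf];
//     }
//   }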
template <typename U, int values_per_thread>
inline void qouter(
const thread uint8_t* w,
U x,
U scale,
thread U* result,
const threadgroup U* lut) {
for (int i = 0; i < (values_per_thread / 2); i++) {
result[2 * i] += x * scale * lut[w[i] & 0xf];
result[2 * i + 1] += x * scale * lut[(w[i] >> 4) & 0xf];
}
}
template <typename U, int N>
inline void dequantize(
const device uint8_t* w,
U scale,
threadgroup U* w_local,
const threadgroup U* lut) {
for (int i = 0; i < (N / 2); i++) {
w_local[2 * i] = scale * lut[w[i] & 0xf];
w_local[2 * i + 1] = scale * lut[(w[i] >> 4) & 0xf];
}
}
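// Cooperative loader that dequantizes a tile of packed weights into
// threadgroup memory: each thread expands n_reads packs of pack_factor values.
// With reduction_dim == 1 next() advances along the contraction dimension and
// the scale pointer steps every group_steps tiles; with reduction_dim == 0 it
// advances BROWS rows at a time and the scales step by group_stride.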
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short group_size,
typename S>
struct QuantizedBlockLoader {
static_assert(
BCOLS <= group_size,
"The group size should be larger than the columns");
static_assert(
group_size % BCOLS == 0,
"The group size should be divisible by the columns");
MLX_MTL_CONST short pack_factor = get_pack_factor<8>();
MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack();
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
MLX_MTL_CONST short n_reads =
(BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
MLX_MTL_CONST short group_steps = group_size / BCOLS;
const int src_ld;
const int tile_stride;
short group_step_cnt;
const int group_stride;
const short thread_idx;
const short bi;
const short bj;
threadgroup T* dst;
const device uint8_t* src;
const device S* scales;
threadgroup T* lut;
QuantizedBlockLoader(
const device uint8_t* src_,
const device S* scales_,
const int src_ld_,
threadgroup T* dst_,
threadgroup T* lut_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(
reduction_dim ? BCOLS_PACKED * bytes_per_pack
: BROWS * src_ld * bytes_per_pack / pack_factor),
group_step_cnt(0),
group_stride(BROWS * src_ld / group_size),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(n_reads * thread_idx / BCOLS_PACKED),
bj((n_reads * thread_idx) % BCOLS_PACKED),
dst(dst_ + bi * dst_ld + bj * pack_factor),
src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
bj * bytes_per_pack),
scales(scales_ + bi * src_ld / group_size),
lut(lut_) {
load_mxfp4_lut(lut, simd_group_id, simd_lane_id);
}
void load_unsafe() const {
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
return;
}
T scale = dequantize_scale<T>(*scales);
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor>(
src + i * bytes_per_pack, scale, dst + i * pack_factor, lut);
}
}
void load_safe(short2 src_tile_dim) const {
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
return;
}
// Zero rows that fall outside the tile's valid row extent
if (reduction_dim == 1 && bi >= src_tile_dim.y) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
if (reduction_dim == 0 && bi >= src_tile_dim.y) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
T scale = dequantize_scale<T>(*scales);
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor>(
(device uint8_t*)(src + i * bytes_per_pack),
scale,
dst + i * pack_factor,
lut);
}
}
void next() {
src += tile_stride;
if (reduction_dim == 1) {
if (group_steps > 1) {
group_step_cnt++;
if (group_step_cnt == group_steps) {
group_step_cnt = 0;
scales++;
}
} else {
scales++;
}
} else {
scales += group_stride;
}
}
};
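// Quad-granularity matrix-vector product: the input vector of length D is
// split across the four lanes of a quadgroup, each quad accumulates
// results_per_quadgroup output rows spaced quads_per_simd apart, and quad_sum
// reduces the partial dot products before lane 0 writes them out.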
template <typename T, int group_size, typename S, int D>
METAL_FUNC void mxfp4_qmv_quad_impl(
const device uint32_t* w,
const device S* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
constexpr int pack_factor = 8;
constexpr int values_per_thread = D / QUAD_SIZE;
constexpr int packs_per_thread = values_per_thread / pack_factor;
constexpr int scale_step_per_thread = group_size / values_per_thread;
constexpr int results_per_quadgroup = 8;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_quadgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;
w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
x += tid.x * in_vec_size + quad_lid * values_per_thread;
y += tid.x * out_vec_size + out_row;
load_vector<T, U, values_per_thread>(x, x_thread);
for (int row = 0; row < results_per_quadgroup; row++) {
auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
const device S* sl = scales + row * in_vec_size_g * quads_per_simd;
U s = dequantize_scale<U>(sl[0]);
if (row * quads_per_simd + out_row < out_vec_size) {
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
}
}
for (int row = 0; row < results_per_quadgroup; row++) {
result[row] = quad_sum(result[row]);
if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
y[row * quads_per_simd] = static_cast<T>(result[row]);
}
}
}
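// Fast matrix-vector path for aligned sizes: two simdgroups per threadgroup,
// each producing four output rows. Every lane loads two 32-bit packs (16
// values) per iteration, so a simdgroup sweeps the input in blocks of 512.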
template <typename T, int group_size, typename S>
METAL_FUNC void mxfp4_qmv_fast_impl(
const device uint32_t* w,
const device S* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
constexpr int packs_per_thread = 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = get_pack_factor<32>();
constexpr int bytes_per_pack = get_bytes_per_pack<32>();
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.x * in_vec_size + simd_lid * values_per_thread;
y += tid.x * out_vec_size + out_row;
for (int k = 0; k < in_vec_size; k += block_size) {
load_vector<T, U, values_per_thread>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device auto* sl = scales + row * in_vec_size_g;
U s = dequantize_scale<U>(sl[0]);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
}
ws += block_size * bytes_per_pack / pack_factor;
scales += block_size / group_size;
x += block_size;
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(result[row]);
}
}
}
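// General matrix-vector product with edge handling. Two cases: when the output
// has fewer rows than one tile, every read is guarded; otherwise the last tile
// is shifted back (used_out_row) so a full tile can be recomputed safely.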
template <typename T, int group_size, typename S>
METAL_FUNC void mxfp4_qmv_impl(
const device uint32_t* w,
const device S* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int packs_per_thread = 1;
constexpr int pack_factor = get_pack_factor<32>();
constexpr int bytes_per_pack = get_bytes_per_pack<32>();
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
if (out_row >= out_vec_size) {
return;
}
// In this case we need to guard all our reads properly because there isn't
// even one full tile in the matrix
if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
ws +=
out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.x * in_vec_size + simd_lid * values_per_thread;
y += tid.x * out_vec_size + out_row;
int k = 0;
for (; k < in_vec_size - block_size; k += block_size) {
load_vector<T, U, values_per_thread>(x, x_thread);
for (int row = 0; out_row + row < out_vec_size; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device auto* sl = scales + row * in_vec_size_g;
U s = dequantize_scale<U>(sl[0]);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
}
ws += block_size * bytes_per_pack / pack_factor;
scales += block_size / group_size;
x += block_size;
}
const int remaining = clamp(
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
0,
values_per_thread);
if (remaining > 0) {
load_vector_safe<T, U, values_per_thread>(x, x_thread, remaining);
for (int row = 0; out_row + row < out_vec_size; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device auto* sl = scales + row * in_vec_size_g;
U s = dequantize_scale<U>(sl[0]);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
}
}
for (int row = 0; out_row + row < out_vec_size; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(result[row]);
}
}
}
// In this case the last tile is moved back to redo some output values
else {
ws += used_out_row * in_vec_size_w +
simd_lid * packs_per_thread * bytes_per_pack;
scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.x * in_vec_size + simd_lid * values_per_thread;
y += tid.x * out_vec_size + used_out_row;
int k = 0;
for (; k < in_vec_size - block_size; k += block_size) {
load_vector<T, U, values_per_thread>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device auto* sl = scales + row * in_vec_size_g;
U s = dequantize_scale<U>(sl[0]);
result[row] += qdot<U, values_per_thread>(wl, x_thread, s, lut);
}
ws += block_size * bytes_per_pack / pack_factor;
scales += block_size / group_size;
x += block_size;
}
const int remaining = clamp(
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
0,
values_per_thread);
if (remaining > 0) {
load_vector_safe<T, U, values_per_thread>(x, x_thread, remaining);
for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device auto* sl = scales + row * in_vec_size_g;
U s = dequantize_scale<U>(sl[0]);
result[row] +=
qdot_safe<U, values_per_thread>(wl, x_thread, s, lut, remaining);
}
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(result[row]);
}
}
}
}
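// Vector-matrix product with w packed along the output dimension: each lane
// owns one input element per step and accumulates it against a 32-wide strip
// of packed weights via qouter; the partial columns are then simd_sum-reduced.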
template <typename T, const int group_size, typename S>
METAL_FUNC void mxfp4_qvm_impl(
const device uint32_t* w,
const device S* scales,
const device T* x,
device T* y,
const int in_vec_size,
const int out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup float* lut) {
constexpr int num_simdgroups = 2;
constexpr int pack_factor = get_pack_factor<32>();
constexpr int bytes_per_pack = get_bytes_per_pack();
constexpr int tn = 32 / pack_factor;
constexpr int block_size = SIMD_SIZE;
using W_T = uint32_t;
const device W_T* ws = (const device W_T*)w;
typedef float U;
typedef struct {
W_T wi[tn * bytes_per_pack];
} vec_w;
thread vec_w w_local;
thread U result[tn * pack_factor] = {0};
thread U scale = 0;
thread U x_local = 0;
load_mxfp4_lut(lut, simd_gid, simd_lid);
// Adjust positions
const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
const int out_vec_size_g = out_vec_size / group_size;
int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid);
ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
scales += out_col / group_size + simd_lid * out_vec_size_g;
x += tid.x * in_vec_size + simd_lid;
y += tid.x * out_vec_size + out_col;
if (out_col >= out_vec_size) {
return;
}
// Loop over in_vec in blocks of block_size
int remaining = in_vec_size % block_size;
if (remaining == 0) {
for (int i = 0; i < in_vec_size; i += block_size) {
x_local = *x;
scale = dequantize_scale<U>(*scales);
w_local = *((device vec_w*)ws);
qouter<U, tn * pack_factor>(
(thread uint8_t*)&w_local, x_local, scale, result, lut);
x += block_size;
scales += block_size * out_vec_size_g;
ws += block_size * out_vec_size_w;
}
} else {
for (int i = block_size; i < in_vec_size; i += block_size) {
x_local = *x;
scale = dequantize_scale<U>(*scales);
w_local = *((device vec_w*)ws);
qouter<U, tn * pack_factor>(
(thread uint8_t*)&w_local, x_local, scale, result, lut);
x += block_size;
scales += block_size * out_vec_size_g;
ws += block_size * out_vec_size_w;
}
if (static_cast<int>(simd_lid) < remaining) {
x_local = *x;
scale = dequantize_scale<U>(*scales);
w_local = *((device vec_w*)ws);
} else {
x_local = 0;
scale = 0;
}
qouter<U, tn * pack_factor>(
(thread uint8_t*)&w_local, x_local, scale, result, lut);
}
// Accumulate in the simdgroup
#pragma clang loop unroll(full)
for (int k = 0; k < tn * pack_factor; k++) {
result[k] = simd_sum(result[k]);
}
// Store the result
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int k = 0; k < tn * pack_factor; k++) {
y[k] = static_cast<T>(result[k]);
}
}
}
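// Matrix-matrix product against a transposed (N x K, packed along K) weight
// matrix using steel BlockMMA tiles. The four branches below differ only in
// whether the X and W loaders take the guarded load_safe path on edge tiles.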
template <
typename T,
const int group_size,
typename S,
const bool aligned_N,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void mxfp4_qmm_t_impl(
const device uint32_t* w,
const device S* scales,
const device T* x,
device T* y,
threadgroup T* Xs,
threadgroup T* Ws,
const constant int& K,
const constant int& N,
const constant int& M,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup T* lut) {
static_assert(BK >= SIMD_SIZE, "BK should be at least as large as SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
(void)lid;
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int pack_factor = get_pack_factor<8>();
constexpr int bytes_per_pack = get_bytes_per_pack();
constexpr int BK_padded = (BK + 16 / sizeof(T));
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::
BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
using loader_x_t =
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = QuantizedBlockLoader<
T,
BN,
BK,
BK_padded,
1,
WM * WN * SIMD_SIZE,
group_size,
S>;
// Set the block
const int K_w = K * bytes_per_pack / pack_factor;
const int K_g = K / group_size;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
auto wl = (const device uint8_t*)w;
x += y_row * static_cast<int64_t>(K);
wl += y_col * K_w;
scales += y_col * K_g;
y += y_row * static_cast<int64_t>(N) + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
const short num_outs = min(BN, N - y_col);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
loader_w_t loader_w(wl, scales, K, Ws, lut, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
if (num_els < BM) {
if (!aligned_N && num_outs < BN) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_safe(short2(BK, num_outs));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
} else {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
}
} else {
if (!aligned_N && num_outs < BN) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_safe(short2(BK, num_outs));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
} else {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
}
}
// Store results to device memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (num_els < BM || num_outs < BN) {
mma_op.store_result_safe(y, N, short2(num_outs, num_els));
} else {
mma_op.store_result(y, N);
}
}
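// Matrix-matrix product against a non-transposed (K x N) weight matrix. K may
// be unaligned, so a remainder block is handled with guarded loads after the
// main loop.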
template <
typename T,
const int group_size,
typename S,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void mxfp4_qmm_n_impl(
const device uint32_t* w,
const device S* scales,
const device T* x,
device T* y,
threadgroup T* Xs,
threadgroup T* Ws,
const constant int& K,
const constant int& N,
const constant int& M,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
threadgroup T* lut) {
static_assert(BK >= SIMD_SIZE, "BK should be at least as large as SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
(void)lid;
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int pack_factor = get_pack_factor<8>();
constexpr int bytes_per_pack = get_bytes_per_pack();
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::
BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
using loader_x_t = mlx::steel::
BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
using loader_w_t = QuantizedBlockLoader<
T,
BK,
BN,
BN_padded,
0,
WM * WN * SIMD_SIZE,
group_size,
S>;
auto wl = (const device uint8_t*)w;
// Set the block
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
x += y_row * static_cast<int64_t>(K);
wl += y_col * bytes_per_pack / pack_factor;
scales += y_col / group_size;
y += y_row * static_cast<int64_t>(N) + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
loader_w_t loader_w(wl, scales, N, Ws, lut, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
if (num_els < BM) {
if ((K % BK) != 0) {
const int k_blocks = K / BK;
for (int k = 0; k < k_blocks; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
const short num_k = K - k_blocks * BK;
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(num_k, num_els));
loader_w.load_safe(short2(BN, num_k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
} else {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
}
} else {
if ((K % BK) != 0) {
const int k_blocks = K / BK;
for (int k = 0; k < k_blocks; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
const short num_k = K - k_blocks * BK;
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(num_k, BM));
loader_w.load_safe(short2(BN, num_k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
} else {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
}
}
// Store results to device memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (num_els < BM) {
mma_op.store_result_safe(y, N, short2(BN, num_els));
} else {
mma_op.store_result(y, N);
}
}
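// Advance x, w, scales, and y to the batch element selected by tid.z,
// resolving (possibly broadcast) strides for w and its scales.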
template <typename T, typename S>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device S*& scales,
device T*& y,
int output_stride,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx = tid.z;
uint32_t w_idx = tid.z;
if (x_batch_ndims == 1) {
x += x_idx * x_strides[0];
} else {
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
}
if (w_batch_ndims == 1) {
w += w_idx * w_strides[0];
scales += w_idx * s_strides[0];
} else {
ulong2 idx = elem_to_loc_broadcast(
w_idx, w_shape, w_strides, s_strides, w_batch_ndims);
w += idx.x;
scales += idx.y;
}
y += tid.z * output_stride;
}
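// Gather variant: lhs_indices/rhs_indices first select which batch element of
// x and w each output slice should read from.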
template <typename T, typename S>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device S*& scales,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
device T*& y,
int output_stride,
const constant int& batch_ndims,
const constant int* batch_shape,
const constant int64_t* lhs_strides,
const constant int64_t* rhs_strides,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx;
uint32_t w_idx;
if (batch_ndims == 1) {
x_idx = lhs_indices[tid.z * lhs_strides[0]];
w_idx = rhs_indices[tid.z * rhs_strides[0]];
} else {
ulong2 idx = elem_to_loc_broadcast(
tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
x_idx = lhs_indices[idx.x];
w_idx = rhs_indices[idx.y];
}
if (x_batch_ndims == 1) {
x += x_idx * x_strides[0];
} else {
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
}
if (w_batch_ndims == 1) {
w += w_idx * w_strides[0];
scales += w_idx * s_strides[0];
} else {
ulong2 idx = elem_to_loc_broadcast(
w_idx, w_shape, w_strides, s_strides, w_batch_ndims);
w += idx.x;
scales += idx.y;
}
y += tid.z * output_stride;
}
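// Kernel entry points. Each one adjusts the batch offsets when needed, sets up
// a threadgroup LUT, and forwards to the corresponding *_impl above.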
template <typename T, int group_size, typename S, int D, bool batched>
[[kernel]] void mxfp4_qmv_quad(
const device uint32_t* w,
const device S* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets(
x,
w,
scales,
y,
out_vec_size * M,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qmv_quad_impl<T, group_size, S, D>(
w,
scales,
x,
y,
in_vec_size,
out_vec_size,
tid,
quad_gid,
quad_lid,
simd_gid,
simd_lid,
lut);
}
template <typename T, int group_size, typename S, bool batched>
[[kernel]] void mxfp4_qmv_fast(
const device uint32_t* w,
const device S* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets(
x,
w,
scales,
y,
out_vec_size * M,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qmv_fast_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <typename T, const int group_size, typename S, bool batched>
[[kernel]] void mxfp4_qmv(
const device uint32_t* w,
const device S* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets(
x,
w,
scales,
y,
out_vec_size * M,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qmv_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <typename T, const int group_size, typename S, bool batched>
[[kernel]] void mxfp4_qvm(
const device uint32_t* w,
const device S* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets(
x,
w,
scales,
y,
out_vec_size * M,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
tid);
}
threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <typename T, const int group_size, typename S, int split_k = 32>
[[kernel]] void mxfp4_qvm_split_k(
const device uint32_t* w,
const device S* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
const constant int& final_block_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets(
x,
w,
scales,
y,
out_vec_size * M,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
tid);
// When (in_vec_size % split_k != 0) the final block needs to be smaller
int in_vec_size_adj =
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>(
w,
scales,
x,
y,
in_vec_size_adj,
out_vec_size,
tid,
simd_gid,
simd_lid,
lut);
}
template <
typename T,
const int group_size,
typename S,
const bool aligned_N,
const bool batched,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_qmm_t(
const device uint32_t* w,
const device S* scales,
const device T* x,
device T* y,
const constant int& K,
const constant int& N,
const constant int& M,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr int BK_padded = (BK + 16 / sizeof(T));
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
threadgroup T lut[16];
if (batched) {
adjust_matrix_offsets(
x,
w,
scales,
y,
M * N,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
tid);
}
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <
typename T,
const int group_size,
typename S,
const bool batched,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_qmm_n(
const device uint32_t* w,
const device S* scales,
const device T* x,
device T* y,
const constant int& K,
const constant int& N,
const constant int& M,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded];
threadgroup T lut[16];
if (batched) {
adjust_matrix_offsets(
x,
w,
scales,
y,
M * N,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
tid);
}
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <typename T, int group_size, typename S>
[[kernel]] void mxfp4_gather_qmv_fast(
const device uint32_t* w,
const device S* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
const constant int& batch_ndims,
const constant int* batch_shape,
const constant int64_t* lhs_strides,
const constant int64_t* rhs_strides,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets(
x,
w,
scales,
lhs_indices,
rhs_indices,
y,
out_vec_size * M,
batch_ndims,
batch_shape,
lhs_strides,
rhs_strides,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
tid);
threadgroup float lut[16];
mxfp4_qmv_fast_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <typename T, int group_size, typename S>
[[kernel]] void mxfp4_gather_qmv(
const device uint32_t* w,
const device S* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
const constant int& batch_ndims,
const constant int* batch_shape,
const constant int64_t* lhs_strides,
const constant int64_t* rhs_strides,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets(
x,
w,
scales,
lhs_indices,
rhs_indices,
y,
out_vec_size * M,
batch_ndims,
batch_shape,
lhs_strides,
rhs_strides,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
tid);
threadgroup float lut[16];
mxfp4_qmv_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <typename T, int group_size, typename S>
[[kernel]] void mxfp4_gather_qvm(
const device uint32_t* w,
const device S* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
const constant int& batch_ndims,
const constant int* batch_shape,
const constant int64_t* lhs_strides,
const constant int64_t* rhs_strides,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
int M = x_shape[x_batch_ndims];
adjust_matrix_offsets(
x,
w,
scales,
lhs_indices,
rhs_indices,
y,
out_vec_size * M,
batch_ndims,
batch_shape,
lhs_strides,
rhs_strides,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
tid);
threadgroup float lut[16];
mxfp4_qvm_impl<T, group_size>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut);
}
template <
typename T,
const int group_size,
typename S,
const bool aligned_N,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_gather_qmm_t(
const device uint32_t* w,
const device S* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
device T* y,
const constant int& K,
const constant int& N,
const constant int& M,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
const constant int& batch_ndims,
const constant int* batch_shape,
const constant int64_t* lhs_strides,
const constant int64_t* rhs_strides,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr int BK_padded = (BK + 16 / sizeof(T));
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
threadgroup T lut[16];
adjust_matrix_offsets(
x,
w,
scales,
lhs_indices,
rhs_indices,
y,
M * N,
batch_ndims,
batch_shape,
lhs_strides,
rhs_strides,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
tid);
mxfp4_qmm_t_impl<T, group_size, S, aligned_N, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
template <
typename T,
const int group_size,
typename S,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void mxfp4_gather_qmm_n(
const device uint32_t* w,
const device S* scales,
const device T* x,
const device uint32_t* lhs_indices,
const device uint32_t* rhs_indices,
device T* y,
const constant int& K,
const constant int& N,
const constant int& M,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
const constant int& batch_ndims,
const constant int* batch_shape,
const constant int64_t* lhs_strides,
const constant int64_t* rhs_strides,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded];
threadgroup T lut[16];
adjust_matrix_offsets(
x,
w,
scales,
lhs_indices,
rhs_indices,
y,
M * N,
batch_ndims,
batch_shape,
lhs_strides,
rhs_strides,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
tid);
mxfp4_qmm_n_impl<T, group_size, S, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
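// Gathered matmul where each block of BM rows may mix several experts: the
// index array is scanned for runs of rows that share an expert, one GEMM is
// issued per run against that expert's weights, and store_result_slice writes
// back only the rows belonging to the run.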
template <
typename T,
int group_size,
typename S,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose>
[[kernel]] void mxfp4_gather_qmm_rhs(
const device T* x,
const device uint32_t* w,
const device S* scales,
const device uint32_t* indices,
device T* y,
const constant int& M,
const constant int& N,
const constant int& K,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
constexpr int pack_factor = get_pack_factor<8>();
constexpr int bytes_per_pack = get_bytes_per_pack();
constexpr int BK_padded = (BK + 16 / sizeof(T));
constexpr int BN_padded = (BN + 16 / sizeof(T));
threadgroup T lut[16];
using mma_t = mlx::steel::BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
false,
transpose,
BK_padded,
transpose ? BK_padded : BN_padded>;
using loader_x_t =
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_w_t = QuantizedBlockLoader<
T,
transpose ? BN : BK,
transpose ? BK : BN,
transpose ? BK_padded : BN_padded,
transpose,
WM * WN * SIMD_SIZE,
group_size,
S>;
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
// Compute the block
const int K_w = K * bytes_per_pack / pack_factor;
const int K_g = K / group_size;
const int N_w = N * bytes_per_pack / pack_factor;
const int N_g = N / group_size;
const int K_it = K / BK;
const size_t stride_w = transpose ? N * K_w : K * N_w;
const size_t stride_s = transpose ? N * K_g : K * N_g;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
const size_t y_row_long = size_t(y_row);
const size_t y_col_long = size_t(y_col);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));
// Calculate the final tiles in the case that K is not aligned
const int k_remain = K - K_it * BK;
const short2 tile_x = short2(k_remain, tgp_bm);
const short2 tile_w =
transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
// Move x and output to the correct block
auto wl = (const device uint8_t*)w;
x += y_row_long * K;
y += y_row_long * N + y_col_long;
wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
scales += transpose ? y_col_long * K_g : y_col / group_size;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = indices[y_row];
short offset_next = 0;
int n = 0;
while (n < tgp_bm) {
n++;
offset = offset_next;
index = index_next;
offset_next = tgp_bm;
for (; n < tgp_bm; n++) {
if (indices[y_row + n] != index) {
offset_next = n;
index_next = indices[y_row + n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
thread loader_w_t loader_w(
wl + index * stride_w,
scales + index * stride_s,
transpose ? K : N,
Ws,
lut,
simd_group_id,
simd_lane_id);
// Everything is aligned; no bounds checks are needed
if (align_M && align_N) {
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(y, N);
} else {
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
} else {
// The tile is fully aligned, so do the alignment checks outside the hot loop
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
// Store results to device memory
if (offset_next - offset == BM) {
mma_op.store_result(y, N);
} else {
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
}
// Tile partially aligned; guard the rows
else if (align_N || tgp_bn == BN) {
gemm_loop_unaligned<false, true, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(BN, offset_next));
}
// Tile partially aligned; guard the cols
else if (align_M || tgp_bm == BM) {
gemm_loop_unaligned<true, false, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(tgp_bn, offset_next));
}
// Nothing aligned, so guard both rows and cols
else {
gemm_loop_unaligned<false, false, transpose>(
Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
if (!align_K) {
threadgroup_barrier(mem_flags::mem_threadgroup);
gemm_loop_finalize(
Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
}
mma_op.store_result_slice(
y, N, short2(0, offset), short2(tgp_bn, offset_next));
}
}
}
}