cpu mxfp4

This commit is contained in:
Awni Hannun 2025-08-20 17:18:47 -07:00
parent 51449428dd
commit 8da1c64fe9
3 changed files with 379 additions and 136 deletions

View File

@ -13,6 +13,35 @@ namespace mlx::core {
namespace {
// Lookup table from a 4-bit MXFP4 code to its float value. Codes 0-7 hold the
// non-negative magnitudes; codes 8-15 hold the same magnitudes negated.
static const float MXFP4_LUT[16] = {
    // codes 0..7
    +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
    // codes 8..15
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
// Decodes an 8-bit scale code into a value of type T.
//
// The code is shifted left by 7 bits into a 16-bit pattern which is then read
// back as a bfloat16_t and converted to T. A code of 0 is special-cased to the
// pattern 0x40 instead of the all-zero pattern.
// NOTE(review): presumably the code is a biased power-of-two exponent and
// 0x40 is the bfloat16 subnormal for the smallest such power -- confirm
// against the format specification.
template <typename T>
static inline T dequantize_scale(uint8_t s) {
  using ScaleBits = union {
    bfloat16_t f;
    uint16_t i;
  };
  ScaleBits bits;
  if (s == 0) {
    bits.i = 0x40;
  } else {
    bits.i = (static_cast<uint16_t>(s) << 7);
  }
  return static_cast<T>(bits.f);
}
// Returns the pack factor used for `bits`-wide quantized values: 8 for 3- and
// 5-bit values, 4 for 6-bit values, and otherwise wsize / bits.
inline constexpr short get_pack_factor(int bits, int wsize = 8) {
  switch (bits) {
    case 3:
    case 5:
      return 8;
    case 6:
      return 4;
    default:
      return wsize / bits;
  }
}
@ -407,50 +436,230 @@ void _qmm_dispatch(
}
}
// template <typename T>
// void _qmm_mxfp4_dispatch_typed(
// array& out,
// const array& x,
// const array& w,
// const array& scales,
// bool transposed_w) {
// int K = x.shape(-1);
// int M = x.ndim() > 1 ? x.shape(-2) : 1;
// int N = out.shape(-1);
// int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
// int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
// int batch_size = x.size() / (K * M);
//
// auto out_ptr = out.data<T>();
// auto x_ptr = x.data<T>();
// auto w_ptr = w.data<uint32_t>();
// auto scales_ptr = scales.data<T>();
// for (int i = 0; i < batch_size; i++) {
// _qmm_mxfp4_dispatch_typed<T>(
// out_ptr + i * M * N,
// x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
// w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
// scales_ptr + elem_to_loc(i * g_els, scales.shape(),
// scales.strides()), M, N, K, transposed_w);
// }
// }
//
//
// void _qmm_mxfp4_dispatch(
// array& out,
// const array& x,
// const array& w,
// const array& scales,
// bool transposed_w) {
// switch (x.dtype()) {
// case bfloat16:
// _qmm_mxfp4_dispatch_typed<bfloat16>(out, x, w, scales, transposed_w);
// break;
// default:
// throw std::invalid_argument(
// "[quantized_matmul] only bfloat is supported for mxfp4");
// }
// }
// Multiplies x by MXFP4-quantized weights that are stored untransposed, i.e.
// as one packed run of N 4-bit codes for each k.
//
// result: output buffer; M * N elements are overwritten.
// x:      input values; M * K elements are read sequentially.
// w:      packed weights, walked byte by byte. Each byte carries two 4-bit
//         codes (low nibble first) that index MXFP4_LUT.
// scales: one byte per group of 32 consecutive outputs, decoded with
//         dequantize_scale<T>.
//
// The weight and scale streams restart from the beginning for every row of x.
// NOTE(review): the group loop advances n in steps of 32, so N is presumably
// required to be a multiple of 32 -- confirm with the callers.
template <typename T>
void mxfp4_qmm(
    T* result,
    const T* x,
    const uint32_t* w,
    const uint8_t* scales,
    int M,
    int N,
    int K) {
  constexpr int group_size = 32;
  // The weights are consumed as bytes (wsize = 8), so one pack is two codes.
  constexpr int pack_factor = get_pack_factor(4, 8);
  constexpr int packs_in_group = group_size / pack_factor;

  for (int m = 0; m < M; m++) {
    const uint8_t* w_local = reinterpret_cast<const uint8_t*>(w);
    const uint8_t* scales_local = scales;

    // The outputs of this row are accumulated in place, so clear them first.
    std::fill(result, result + N, 0);

    for (int k = 0; k < K; k++) {
      T* result_local = result;
      T xi = *x++;

      for (int n = 0; n < N; n += group_size) {
        T scale = dequantize_scale<T>(*scales_local++);
        for (int ng = 0; ng < packs_in_group; ng++) {
          uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
          for (int p = 0; p < pack_factor; p++) {
            (*result_local++) +=
                xi * scale * static_cast<T>(MXFP4_LUT[wi & 0xf]);
            // Move the next code into the low nibble.
            wi >>= 4;
          }
        }
      }
    }

    result += N;
  }
}
// Multiplies x by MXFP4-quantized weights that are stored transposed, i.e. as
// one packed run of K 4-bit codes for each output n.
//
// result: output buffer; M * N elements are written sequentially.
// x:      input values; each of the M rows holds K elements.
// w:      packed weights, walked byte by byte. Each byte carries two 4-bit
//         codes (low nibble first) that index MXFP4_LUT.
// scales: one byte per group of 32 consecutive k values, decoded with
//         dequantize_scale<T>.
//
// Within a group the products are summed unscaled and the group scale is
// applied once to that partial sum. The weight and scale streams restart from
// the beginning for every row of x.
// NOTE(review): the group loop advances k in steps of 32, so K is presumably
// required to be a multiple of 32 -- confirm with the callers.
template <typename T>
void mxfp4_qmm_t(
    T* result,
    const T* x,
    const uint32_t* w,
    const uint8_t* scales,
    int M,
    int N,
    int K) {
  constexpr int group_size = 32;
  // The weights are consumed as bytes (wsize = 8), so one pack is two codes.
  constexpr int pack_factor = get_pack_factor(4, 8);
  constexpr int packs_in_group = group_size / pack_factor;

  for (int m = 0; m < M; m++) {
    const uint8_t* w_local = reinterpret_cast<const uint8_t*>(w);
    const uint8_t* scales_local = scales;

    for (int n = 0; n < N; n++) {
      const T* x_local = x;
      T sum = 0;

      for (int k = 0; k < K; k += group_size) {
        T scale = dequantize_scale<T>(*scales_local++);
        T gsum = 0;
        for (int kw = 0; kw < packs_in_group; kw++) {
          uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
          for (int p = 0; p < pack_factor; p++) {
            gsum += (*x_local++) * static_cast<T>(MXFP4_LUT[wi & 0xf]);
            // Move the next code into the low nibble.
            wi >>= 4;
          }
        }
        sum += scale * gsum;
      }

      *result = sum;
      result++;
    }

    x += K;
  }
}
// Unpacks the eight 4-bit codes of the 32-bit word at `w` and returns their
// MXFP4_LUT values as a float SIMD vector. Only S == 8 is implemented; any
// other S throws std::runtime_error when called.
template <int S>
simd::Simd<float, S> mxfp4_extract_bits_simd(const uint32_t* w) {
  if constexpr (S != 8) {
    // Keeps the template well-formed for other sizes; not expected to run.
    throw std::runtime_error("Unsupported combination for simd qmm.");
  } else {
    // Lane i shifts the word right by 4 * i so that code i lands in the low
    // nibble of that lane.
    constexpr std::array<uint32_t, 8> nibble_shifts = {
        {0, 4, 8, 12, 16, 20, 24, 28}};
    auto shift_vec(*(simd::Simd<uint32_t, S>*)&nibble_shifts);

    auto codes = simd::Simd<uint32_t, S>(*w);
    codes = (codes >> shift_vec) & 0xf;

    simd::Simd<float, S> values;
    for (int lane = 0; lane < S; ++lane) {
      values[lane] = MXFP4_LUT[codes[lane]];
    }
    return values;
  }
}
template <typename T>
void mxfp4_qmm_t_simd(
T* result,
const T* x,
const uint32_t* w,
const uint8_t* scales,
int M,
int N,
int K) {
constexpr int group_size = 32;
constexpr int pack_factor = 32 / 4;
constexpr int packs_in_group = group_size / pack_factor;
constexpr int S = simd::max_size<T>;
static_assert(
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
constexpr int packs_per_simd = S / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const uint8_t* scales_local = scales;
for (int n = 0; n < N; n++) {
simd::Simd<float, S> acc(0);
auto x_local = x;
for (int k = 0; k < K; k += group_size) {
T scale = dequantize_scale<T>(*scales_local++);
simd::Simd<float, S> g_acc(0);
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
// Extract bits
auto wf = mxfp4_extract_bits_simd<S>(w_local);
w_local += packs_per_simd;
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
g_acc = g_acc + x_simd * wf;
x_local += S;
}
acc = acc + scale * g_acc;
}
*result = T(simd::sum(acc));
result++;
}
x += K;
}
}
// Selects the MXFP4 matmul kernel for one (M, N, K) problem.
//
// Untransposed weights always use mxfp4_qmm. Transposed weights use the SIMD
// kernel when simd::max_size<T> is a multiple of 8 (the number of 4-bit codes
// in a 32-bit word) and the scalar mxfp4_qmm_t kernel otherwise.
template <typename T>
void mxfp4_qmm_dispatch_transpose(
    T* result,
    const T* x,
    const uint32_t* w,
    const uint8_t* scales,
    int M,
    int N,
    int K,
    bool transposed_w) {
  if (!transposed_w) {
    mxfp4_qmm<T>(result, x, w, scales, M, N, K);
    return;
  }

  // The SIMD width has to cover a whole number of packed words.
  if constexpr (simd::max_size<T> % 8 == 0) {
    mxfp4_qmm_t_simd<T>(result, x, w, scales, M, N, K);
  } else {
    mxfp4_qmm_t<T>(result, x, w, scales, M, N, K);
  }
}
// Runs the MXFP4 matmul for every batch entry of x with element type T.
//
// K and M come from the last two dimensions of x (M is 1 when x has a single
// dimension) and N from the last dimension of out. The output is addressed
// densely (M * N elements per batch entry) while x, w and scales are addressed
// through elem_to_loc with their own shapes and strides. When w has at most
// two dimensions its per-batch element counts are 0, so every batch entry
// uses the same weights and scales.
template <typename T>
void mxfp4_qmm_dispatch_typed(
    array& out,
    const array& x,
    const array& w,
    const array& scales,
    bool transposed_w) {
  int K = x.shape(-1);
  int M = x.ndim() > 1 ? x.shape(-2) : 1;
  int N = out.shape(-1);

  // Elements per batch entry of w / scales; 0 means "not batched".
  int w_els = 0;
  int g_els = 0;
  if (w.ndim() > 2) {
    w_els = w.shape(-1) * w.shape(-2);
    g_els = scales.shape(-1) * scales.shape(-2);
  }

  int batch_size = x.size() / (K * M);

  auto out_base = out.data<T>();
  auto x_base = x.data<T>();
  auto w_base = w.data<uint32_t>();
  auto scales_base = scales.data<uint8_t>();

  for (int b = 0; b < batch_size; b++) {
    auto out_b = out_base + b * M * N;
    auto x_b = x_base + elem_to_loc(b * M * K, x.shape(), x.strides());
    auto w_b = w_base + elem_to_loc(b * w_els, w.shape(), w.strides());
    auto scales_b = scales_base +
        elem_to_loc(b * g_els, scales.shape(), scales.strides());
    mxfp4_qmm_dispatch_transpose<T>(
        out_b, x_b, w_b, scales_b, M, N, K, transposed_w);
  }
}
// Dispatches the MXFP4 matmul on the dtype of x. float32, float16 and
// bfloat16 are handled; any other dtype throws std::invalid_argument.
void mxfp4_qmm_dispatch(
    array& out,
    const array& x,
    const array& w,
    const array& scales,
    bool transposed_w) {
  switch (x.dtype()) {
    case float32:
      mxfp4_qmm_dispatch_typed<float>(out, x, w, scales, transposed_w);
      return;
    case float16:
      mxfp4_qmm_dispatch_typed<float16_t>(out, x, w, scales, transposed_w);
      return;
    case bfloat16:
      mxfp4_qmm_dispatch_typed<bfloat16_t>(out, x, w, scales, transposed_w);
      return;
    default:
      throw std::invalid_argument(
          "[quantized_matmul] only floating types are supported");
  }
}
template <typename T>
void _bs_qmm_dispatch_typed(
@ -558,6 +767,74 @@ void _bs_qmm_dispatch(
}
}
// Gathered MXFP4 matmul with element type T.
//
// For every entry i of lhs_indices, the value read from lhs_indices selects
// the batch entry of x and the value read from rhs_indices selects the batch
// entry of w and scales. The i-th result is written densely at offset
// i * M * N of out. All inputs are addressed through elem_to_loc with their
// own shapes and strides.
template <typename T>
void mxfp4_bs_qmm_dispatch_typed(
    array& out,
    const array& x,
    const array& w,
    const array& scales,
    const array& lhs_indices,
    const array& rhs_indices,
    bool transposed_w) {
  int K = x.shape(-1);
  int M = x.shape(-2);
  int N = out.shape(-1);

  // Elements per batch entry of w and of scales.
  int w_els = w.shape(-1) * w.shape(-2);
  int g_els = scales.shape(-1) * scales.shape(-2);

  auto out_base = out.data<T>();
  auto x_base = x.data<T>();
  auto w_base = w.data<uint32_t>();
  auto scales_base = scales.data<uint8_t>();
  auto lhs_base = lhs_indices.data<uint32_t>();
  auto rhs_base = rhs_indices.data<uint32_t>();

  for (int i = 0; i < lhs_indices.size(); i++) {
    auto lhs_loc =
        elem_to_loc(i, lhs_indices.shape(), lhs_indices.strides());
    auto rhs_loc =
        elem_to_loc(i, rhs_indices.shape(), rhs_indices.strides());
    int x_idx = lhs_base[lhs_loc];
    int w_idx = rhs_base[rhs_loc];

    auto out_i = out_base + i * M * N;
    auto x_i = x_base + elem_to_loc(x_idx * M * K, x.shape(), x.strides());
    auto w_i = w_base + elem_to_loc(w_idx * w_els, w.shape(), w.strides());
    auto scales_i = scales_base +
        elem_to_loc(w_idx * g_els, scales.shape(), scales.strides());

    mxfp4_qmm_dispatch_transpose<T>(
        out_i, x_i, w_i, scales_i, M, N, K, transposed_w);
  }
}
// Dispatches the gathered MXFP4 matmul on the dtype of x. float32, float16
// and bfloat16 are handled; any other dtype throws std::invalid_argument.
void mxfp4_bs_qmm_dispatch(
    array& out,
    const array& x,
    const array& w,
    const array& scales,
    const array& lhs_indices,
    const array& rhs_indices,
    bool transposed_w) {
  switch (x.dtype()) {
    case bfloat16:
      mxfp4_bs_qmm_dispatch_typed<bfloat16_t>(
          out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
      return;
    case float16:
      mxfp4_bs_qmm_dispatch_typed<float16_t>(
          out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
      return;
    case float32:
      mxfp4_bs_qmm_dispatch_typed<float>(
          out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
      return;
    default:
      throw std::invalid_argument(
          "[quantized_matmul] only floating types are supported");
  }
}
} // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
@ -604,15 +881,13 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
} else {
// encoder.dispatch([out = array::unsafe_weak_copy(out),
// x = array::unsafe_weak_copy(x),
// w = array::unsafe_weak_copy(w),
// scales = array::unsafe_weak_copy(scales),
// group_size_ = group_size_,
// bits_ = bits_,
// transpose_ = transpose_]() mutable {
// _qmm_mxfp4_dispatch(out, x, w, scales, transpose_);
// });
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
transpose_ = transpose_]() mutable {
mxfp4_qmm_dispatch(out, x, w, scales, transpose_);
});
}
}
@ -622,9 +897,8 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto& lhs_indices = inputs[inputs.size() - 2];
auto& rhs_indices = inputs[inputs.size() - 1];
std::vector<array> temps;
auto ensure_row_contiguous_last_dims = [s = stream(),
@ -643,7 +917,6 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc(out.nbytes()));
@ -652,32 +925,46 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
});
if (mode_ == "affine") {
auto biases = ensure_row_contiguous_last_dims(inputs[3]);
encoder.set_input_array(biases);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
});
} else {
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
transpose_ = transpose_]() mutable {
mxfp4_bs_qmm_dispatch(
out, x, w, scales, lhs_indices, rhs_indices, transpose_);
});
}
}
template <typename T, typename U>

View File

@ -95,10 +95,10 @@ inline U qdot(
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (values_per_thread / 4); i++) {
accum +=
(x_thread[4 * i] * lut[ws[i] & 0x000f] +
x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0x000f] +
x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0x000f] +
x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0x000f]);
(x_thread[4 * i] * lut[ws[i] & 0xf] +
x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
}
return scale * accum;
}
@ -115,10 +115,10 @@ inline U qdot_safe(
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (N / 4); i++) {
accum +=
(x_thread[4 * i] * lut[ws[i] & 0x000f] +
x_thread[4 * i + 1] * lut[(ws[i] & 0x00f0) >> 4] +
x_thread[4 * i + 2] * lut[(ws[i] & 0x0f00) >> 8] +
x_thread[4 * i + 3] * lut[(ws[i] & 0xf000) >> 12]);
(x_thread[4 * i] * lut[ws[i] & 0xf] +
x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
}
return scale * accum;
}
@ -131,8 +131,8 @@ inline void qouter(
thread U* result,
const threadgroup U* lut) {
for (int i = 0; i < (values_per_thread / 2); i++) {
result[2 * i] += x * scale * lut[w[i] & 0x0f];
result[2 * i + 1] += x * scale * lut[(w[i] & 0xf0) >> 4];
result[2 * i] += x * scale * lut[w[i] & 0xf];
result[2 * i + 1] += x * scale * lut[(w[i] >> 4) & 0xf];
}
}
@ -143,8 +143,8 @@ inline void dequantize(
threadgroup U* w_local,
const threadgroup U* lut) {
for (int i = 0; i < (N / 2); i++) {
w_local[2 * i] = scale * lut[w[i] & 0x0f];
w_local[2 * i + 1] = scale * lut[(w[i] & 0xf0) >> 4];
w_local[2 * i] = scale * lut[w[i] & 0xf];
w_local[2 * i + 1] = scale * lut[(w[i] >> 4) & 0xf];
}
}

View File

@ -762,50 +762,6 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
}
// Quantizes `packed_w` and packs the resulting codes into uint32 words.
//
// The values are mapped with round((w - biases) / scales), clipped to
// [0, 2^bits - 1] and cast to uint32. `packed_w` is overwritten with the
// intermediate and final results and the final array is also returned.
//
// When `bits` is a power of two, 32 / bits codes are combined per word by
// multiplying with powers of two and summing. Otherwise each code is split
// into its individual bits, the bits are regrouped into runs of 32 and each
// run is reassembled into one word.
array pack_and_quantize(
    array& packed_w,
    const array& scales,
    const array& biases,
    int bits,
    const Stream& s) {
  int el_per_int = 32 / bits;
  array lower(0, packed_w.dtype());
  array upper((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1

  auto normalized = divide(subtract(packed_w, biases, s), scales, s);
  auto clipped = clip(round(normalized, s), lower, upper, s);
  packed_w = astype(clipped, uint32, s);

  if (is_power_of_2(bits)) {
    // Code j of a word is weighted by 2^(bits * j).
    array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
    packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
    auto weighted = multiply(packed_w, shifts, s);
    packed_w = sum(weighted, /* axis= */ 2, /* keepdims= */ false, s);
    return packed_w;
  }

  // This is slow but we have fast GPU/CPU versions of this function so we
  // shouldn't be here often.
  packed_w = expand_dims(packed_w, /* axis= */ -1, s);
  auto shifted = right_shift(packed_w, arange(bits, uint32, s), s);
  packed_w = bitwise_and(shifted, array({1}, uint32), s);

  // Regroup the individual bits into runs of 32.
  auto new_shape = packed_w.shape();
  new_shape[new_shape.size() - 2] = -1;
  new_shape.back() = 32;
  packed_w = reshape(packed_w, new_shape, s);

  array shifts = arange(32, uint32, s);
  auto positioned = left_shift(packed_w, shifts, s);
  packed_w = sum(positioned, /* axis= */ -1, /* keepdims= */ false, s);
  return packed_w;
}
bool Quantize::is_equivalent(const Primitive& other) const {
const Quantize& p_other = static_cast<const Quantize&>(other);
return (