mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-29 00:35:27 +08:00)
cpu mxfp4

parent 51449428dd
commit 8da1c64fe9
@@ -13,6 +13,35 @@ namespace mlx::core {

 namespace {

+const static float MXFP4_LUT[16] = {
+    +0.0f,
+    +0.5f,
+    +1.0f,
+    +1.5f,
+    +2.0f,
+    +3.0f,
+    +4.0f,
+    +6.0f,
+    -0.0f,
+    -0.5f,
+    -1.0f,
+    -1.5f,
+    -2.0f,
+    -3.0f,
+    -4.0f,
+    -6.0f};
+
+template <typename T>
+static inline T dequantize_scale(uint8_t s) {
+  using FOrI = union {
+    bfloat16_t f;
+    uint16_t i;
+  };
+  FOrI out;
+  out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
+  return static_cast<T>(out.f);
+}
+
 inline constexpr short get_pack_factor(int bits, int wsize = 8) {
   return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
 }
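As context for the two helpers above (a minimal standalone sketch, not part of the commit; names here are illustrative): each weight byte packs two 4-bit codes that index MXFP4_LUT, low nibble first, and every 32-value group shares one E8M0 scale byte worth 2^(s - 127). dequantize_scale builds that power of two by shifting s into the bfloat16 exponent field, special-casing s == 0 to the subnormal 2^-127 (bit pattern 0x40).

#include <cmath>
#include <cstdint>
#include <cstdio>

// Same table as the diff: eight positive FP4 values, then their negatives.
static const float kLut[16] = {
    +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// E8M0 scale byte: a raw biased exponent, value = 2^(s - 127).
static float e8m0_to_float(uint8_t s) {
  return std::ldexp(1.0f, static_cast<int>(s) - 127);
}

int main() {
  uint8_t packed = 0x2b; // low nibble 0xb -> -1.5, high nibble 0x2 -> +1.0
  uint8_t scale = 128;   // 2^(128 - 127) = 2.0
  std::printf("%g %g\n",
              e8m0_to_float(scale) * kLut[packed & 0xf],
              e8m0_to_float(scale) * kLut[(packed >> 4) & 0xf]); // -3 2
  return 0;
}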
@@ -407,50 +436,230 @@ void _qmm_dispatch(
   }
 }

-// template <typename T>
-// void _qmm_mxfp4_dispatch_typed(
-//     array& out,
-//     const array& x,
-//     const array& w,
-//     const array& scales,
-//     bool transposed_w) {
-//   int K = x.shape(-1);
-//   int M = x.ndim() > 1 ? x.shape(-2) : 1;
-//   int N = out.shape(-1);
-//   int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
-//   int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
-//   int batch_size = x.size() / (K * M);
-//
-//   auto out_ptr = out.data<T>();
-//   auto x_ptr = x.data<T>();
-//   auto w_ptr = w.data<uint32_t>();
-//   auto scales_ptr = scales.data<T>();
-//   for (int i = 0; i < batch_size; i++) {
-//     _qmm_mxfp4_dispatch_typed<T>(
-//         out_ptr + i * M * N,
-//         x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
-//         w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
-//         scales_ptr + elem_to_loc(i * g_els, scales.shape(),
-//         scales.strides()), M, N, K, transposed_w);
-//   }
-// }
-//
-//
-// void _qmm_mxfp4_dispatch(
-//     array& out,
-//     const array& x,
-//     const array& w,
-//     const array& scales,
-//     bool transposed_w) {
-//   switch (x.dtype()) {
-//     case bfloat16:
-//       _qmm_mxfp4_dispatch_typed<bfloat16>(out, x, w, scales, transposed_w);
-//       break;
-//     default:
-//       throw std::invalid_argument(
-//           "[quantized_matmul] only bfloat is supported for mxfp4");
-//   }
-// }
+template <typename T>
+void mxfp4_qmm(
+    T* result,
+    const T* x,
+    const uint32_t* w,
+    const uint8_t* scales,
+    int M,
+    int N,
+    int K) {
+  constexpr int group_size = 32;
+  constexpr int pack_factor = get_pack_factor(4, 8);
+  constexpr int bytes_per_pack = get_bytes_per_pack(4);
+  constexpr int packs_in_group = group_size / pack_factor;
+
+  for (int m = 0; m < M; m++) {
+    const uint8_t* w_local = (const uint8_t*)w;
+    const uint8_t* scales_local = scales;
+
+    std::fill(result, result + N, 0);
+
+    for (int k = 0; k < K; k++) {
+      T* result_local = result;
+      T xi = *x++;
+
+      for (int n = 0; n < N; n += group_size) {
+        T scale = dequantize_scale<T>(*scales_local++);
+        for (int ng = 0; ng < packs_in_group; ng++) {
+          uint8_t wi = *w_local++;
+#pragma clang loop unroll(full)
+          for (int p = 0; p < pack_factor; p++) {
+            (*result_local++) +=
+                xi * scale * static_cast<T>(MXFP4_LUT[wi & 0xf]);
+            wi >>= 4;
+          }
+        }
+      }
+    }
+
+    result += N;
+  }
+}
+
+template <typename T>
+void mxfp4_qmm_t(
+    T* result,
+    const T* x,
+    const uint32_t* w,
+    const uint8_t* scales,
+    int M,
+    int N,
+    int K) {
+  constexpr int group_size = 32;
+  constexpr int pack_factor = get_pack_factor(4, 8);
+  constexpr int bytes_per_pack = get_bytes_per_pack(4);
+  constexpr int packs_in_group = group_size / pack_factor;
+
+  for (int m = 0; m < M; m++) {
+    const uint8_t* w_local = (const uint8_t*)w;
+    const uint8_t* scales_local = scales;
+
+    for (int n = 0; n < N; n++) {
+      const T* x_local = x;
+      T sum = 0;
+      for (int k = 0; k < K; k += group_size) {
+        T scale = dequantize_scale<T>(*scales_local++);
+
+        T gsum = 0;
+        for (int kw = 0; kw < packs_in_group; kw++) {
+          uint8_t wi = *w_local++;
+#pragma clang loop unroll(full)
+          for (int p = 0; p < pack_factor; p++) {
+            gsum += (*x_local++) * static_cast<T>(MXFP4_LUT[wi & 0xf]);
+            wi >>= 4;
+          }
+        }
+        sum += scale * gsum;
+      }
+      *result = sum;
+      result++;
+    }
+
+    x += K;
+  }
+}
+
+template <int S>
+simd::Simd<float, S> mxfp4_extract_bits_simd(const uint32_t* w) {
+  if constexpr (S == 8) {
+    constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
+    auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
+    auto wi = simd::Simd<uint32_t, S>(*w);
+    wi = wi >> shifts;
+    wi = wi & 0xf;
+    simd::Simd<float, S> w_out;
+    for (int i = 0; i < S; ++i) {
+      w_out[i] = MXFP4_LUT[wi[i]];
+    }
+    return w_out;
+  } else {
+    // Appease compiler.. but should never get here
+    throw std::runtime_error("Unsupported combination for simd qmm.");
+  }
+}
+
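A scalar spot-check of the lane order mxfp4_extract_bits_simd produces (plain C++, not from the commit): broadcasting one uint32_t across eight lanes and shifting lane i by 4*i emits the codes low nibble first, the same order the scalar kernels consume via wi >>= 4.

#include <cstdint>
#include <cstdio>

int main() {
  uint32_t w = 0x76543210; // eight packed 4-bit codes
  for (int i = 0; i < 8; i++) {
    // Lane i of the SIMD path computes (w >> shifts[i]) & 0xf with shifts[i] = 4 * i.
    std::printf("%u ", (w >> (4 * i)) & 0xfu);
  }
  std::printf("\n"); // prints: 0 1 2 3 4 5 6 7
  return 0;
}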
+template <typename T>
+void mxfp4_qmm_t_simd(
+    T* result,
+    const T* x,
+    const uint32_t* w,
+    const uint8_t* scales,
+    int M,
+    int N,
+    int K) {
+  constexpr int group_size = 32;
+  constexpr int pack_factor = 32 / 4;
+  constexpr int packs_in_group = group_size / pack_factor;
+  constexpr int S = simd::max_size<T>;
+  static_assert(
+      S % pack_factor == 0, "SIMD size must be divisible by pack factor");
+  constexpr int packs_per_simd = S / pack_factor;
+
+  for (int m = 0; m < M; m++) {
+    const uint32_t* w_local = w;
+    const uint8_t* scales_local = scales;
+
+    for (int n = 0; n < N; n++) {
+      simd::Simd<float, S> acc(0);
+      auto x_local = x;
+      for (int k = 0; k < K; k += group_size) {
+        T scale = dequantize_scale<T>(*scales_local++);
+
+        simd::Simd<float, S> g_acc(0);
+        for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
+          // Extract bits
+          auto wf = mxfp4_extract_bits_simd<S>(w_local);
+          w_local += packs_per_simd;
+          simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
+          g_acc = g_acc + x_simd * wf;
+          x_local += S;
+        }
+        acc = acc + scale * g_acc;
+      }
+
+      *result = T(simd::sum(acc));
+      result++;
+    }
+    x += K;
+  }
+}
+
+template <typename T>
+void mxfp4_qmm_dispatch_transpose(
+    T* result,
+    const T* x,
+    const uint32_t* w,
+    const uint8_t* scales,
+    int M,
+    int N,
+    int K,
+    bool transposed_w) {
+  if (transposed_w) {
+    // the simd size must be a multiple of the number of elements per word
+    if constexpr (simd::max_size<T> % 8 == 0) {
+      mxfp4_qmm_t_simd<T>(result, x, w, scales, M, N, K);
+    } else {
+      mxfp4_qmm_t<T>(result, x, w, scales, M, N, K);
+    }
+  } else {
+    mxfp4_qmm<T>(result, x, w, scales, M, N, K);
+  }
+}
+
+template <typename T>
+void mxfp4_qmm_dispatch_typed(
+    array& out,
+    const array& x,
+    const array& w,
+    const array& scales,
+    bool transposed_w) {
+  int K = x.shape(-1);
+  int M = x.ndim() > 1 ? x.shape(-2) : 1;
+  int N = out.shape(-1);
+  int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
+  int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
+  int batch_size = x.size() / (K * M);
+
+  auto out_ptr = out.data<T>();
+  auto x_ptr = x.data<T>();
+  auto w_ptr = w.data<uint32_t>();
+  auto scales_ptr = scales.data<uint8_t>();
+  for (int i = 0; i < batch_size; i++) {
+    mxfp4_qmm_dispatch_transpose<T>(
+        out_ptr + i * M * N,
+        x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
+        w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
+        scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
+        M,
+        N,
+        K,
+        transposed_w);
+  }
+}
+
+void mxfp4_qmm_dispatch(
+    array& out,
+    const array& x,
+    const array& w,
+    const array& scales,
+    bool transposed_w) {
+  switch (x.dtype()) {
+    case bfloat16:
+      mxfp4_qmm_dispatch_typed<bfloat16_t>(out, x, w, scales, transposed_w);
+      break;
+    case float16:
+      mxfp4_qmm_dispatch_typed<float16_t>(out, x, w, scales, transposed_w);
+      break;
+    case float32:
+      mxfp4_qmm_dispatch_typed<float>(out, x, w, scales, transposed_w);
+      break;
+    default:
+      throw std::invalid_argument(
+          "[quantized_matmul] only floating types are supported");
+  }
+}
+
 template <typename T>
 void _bs_qmm_dispatch_typed(
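To pin down the layout the transposed loops above imply (a hedged reference, not mlx code; the row-major packing is an assumption read off the kernels): for one output row, the packed weights are K/2 consecutive bytes and the scales K/32 consecutive E8M0 bytes, so mxfp4_qmm_t's per-element result should agree with this naive dequantize-and-dot.

#include <cmath>
#include <cstdint>

static const float kLut[16] = {
    +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// One output element: dequantize a K-long weight row and dot it with x.
float mxfp4_dot_reference(
    const float* x, const uint8_t* w_bytes, const uint8_t* scales, int K) {
  float sum = 0.0f;
  for (int k = 0; k < K; k++) {
    uint8_t byte = w_bytes[k / 2];
    uint8_t code = (k % 2 == 0) ? (byte & 0xf) : (byte >> 4); // low nibble first
    float scale = std::ldexp(1.0f, static_cast<int>(scales[k / 32]) - 127);
    sum += x[k] * scale * kLut[code];
  }
  return sum;
}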
@@ -558,6 +767,74 @@ void _bs_qmm_dispatch(
   }
 }

+template <typename T>
+void mxfp4_bs_qmm_dispatch_typed(
+    array& out,
+    const array& x,
+    const array& w,
+    const array& scales,
+    const array& lhs_indices,
+    const array& rhs_indices,
+    bool transposed_w) {
+  int K = x.shape(-1);
+  int M = x.shape(-2);
+  int N = out.shape(-1);
+
+  int w_els = w.shape(-1) * w.shape(-2);
+  int g_els = scales.shape(-1) * scales.shape(-2);
+
+  auto out_ptr = out.data<T>();
+  auto x_ptr = x.data<T>();
+  auto w_ptr = w.data<uint32_t>();
+  auto scales_ptr = scales.data<uint8_t>();
+  auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
+  auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
+
+  for (int i = 0; i < lhs_indices.size(); i++) {
+    int x_idx = lhs_indices_ptr[elem_to_loc(
+        i, lhs_indices.shape(), lhs_indices.strides())];
+    int w_idx = rhs_indices_ptr[elem_to_loc(
+        i, rhs_indices.shape(), rhs_indices.strides())];
+    mxfp4_qmm_dispatch_transpose<T>(
+        out_ptr + i * M * N,
+        x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
+        w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
+        scales_ptr +
+            elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
+        M,
+        N,
+        K,
+        transposed_w);
+  }
+}
+
+void mxfp4_bs_qmm_dispatch(
+    array& out,
+    const array& x,
+    const array& w,
+    const array& scales,
+    const array& lhs_indices,
+    const array& rhs_indices,
+    bool transposed_w) {
+  switch (x.dtype()) {
+    case float32:
+      mxfp4_bs_qmm_dispatch_typed<float>(
+          out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
+      break;
+    case float16:
+      mxfp4_bs_qmm_dispatch_typed<float16_t>(
+          out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
+      break;
+    case bfloat16:
+      mxfp4_bs_qmm_dispatch_typed<bfloat16_t>(
+          out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
+      break;
+    default:
+      throw std::invalid_argument(
+          "[quantized_matmul] only floating types are supported");
+  }
+}
+
 } // namespace

 void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -604,15 +881,13 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
      _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
    });
  } else {
-    // encoder.dispatch([out = array::unsafe_weak_copy(out),
-    //                   x = array::unsafe_weak_copy(x),
-    //                   w = array::unsafe_weak_copy(w),
-    //                   scales = array::unsafe_weak_copy(scales),
-    //                   group_size_ = group_size_,
-    //                   bits_ = bits_,
-    //                   transpose_ = transpose_]() mutable {
-    //   _qmm_mxfp4_dispatch(out, x, w, scales, transpose_);
-    // });
+    encoder.dispatch([out = array::unsafe_weak_copy(out),
+                      x = array::unsafe_weak_copy(x),
+                      w = array::unsafe_weak_copy(w),
+                      scales = array::unsafe_weak_copy(scales),
+                      transpose_ = transpose_]() mutable {
+      mxfp4_qmm_dispatch(out, x, w, scales, transpose_);
+    });
  }
 }

@@ -622,9 +897,8 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& x_pre = inputs[0];
   auto& w_pre = inputs[1];
   auto& scales_pre = inputs[2];
-  auto& biases_pre = inputs[3];
-  auto& lhs_indices = inputs[4];
-  auto& rhs_indices = inputs[5];
+  auto& lhs_indices = inputs[inputs.size() - 2];
+  auto& rhs_indices = inputs[inputs.size() - 1];

   std::vector<array> temps;
   auto ensure_row_contiguous_last_dims = [s = stream(),
@@ -643,7 +917,6 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto x = ensure_row_contiguous_last_dims(x_pre);
   auto w = ensure_row_contiguous_last_dims(w_pre);
   auto scales = ensure_row_contiguous_last_dims(scales_pre);
-  auto biases = ensure_row_contiguous_last_dims(biases_pre);

   out.set_data(allocator::malloc(out.nbytes()));

@@ -652,32 +925,46 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   encoder.set_input_array(x);
   encoder.set_input_array(w);
   encoder.set_input_array(scales);
-  encoder.set_input_array(biases);
   encoder.set_input_array(lhs_indices);
   encoder.set_input_array(rhs_indices);
   encoder.set_output_array(out);
-  encoder.dispatch([out = array::unsafe_weak_copy(out),
-                    x = array::unsafe_weak_copy(x),
-                    w = array::unsafe_weak_copy(w),
-                    scales = array::unsafe_weak_copy(scales),
-                    biases = array::unsafe_weak_copy(biases),
-                    lhs_indices = array::unsafe_weak_copy(lhs_indices),
-                    rhs_indices = array::unsafe_weak_copy(rhs_indices),
-                    group_size_ = group_size_,
-                    bits_ = bits_,
-                    transpose_ = transpose_]() mutable {
-    _bs_qmm_dispatch(
-        out,
-        x,
-        w,
-        scales,
-        biases,
-        lhs_indices,
-        rhs_indices,
-        group_size_,
-        bits_,
-        transpose_);
-  });
+  if (mode_ == "affine") {
+    auto biases = ensure_row_contiguous_last_dims(inputs[3]);
+    encoder.set_input_array(biases);
+    encoder.dispatch([out = array::unsafe_weak_copy(out),
+                      x = array::unsafe_weak_copy(x),
+                      w = array::unsafe_weak_copy(w),
+                      scales = array::unsafe_weak_copy(scales),
+                      biases = array::unsafe_weak_copy(biases),
+                      lhs_indices = array::unsafe_weak_copy(lhs_indices),
+                      rhs_indices = array::unsafe_weak_copy(rhs_indices),
+                      group_size_ = group_size_,
+                      bits_ = bits_,
+                      transpose_ = transpose_]() mutable {
+      _bs_qmm_dispatch(
+          out,
+          x,
+          w,
+          scales,
+          biases,
+          lhs_indices,
+          rhs_indices,
+          group_size_,
+          bits_,
+          transpose_);
+    });
+  } else {
+    encoder.dispatch([out = array::unsafe_weak_copy(out),
+                      x = array::unsafe_weak_copy(x),
+                      w = array::unsafe_weak_copy(w),
+                      scales = array::unsafe_weak_copy(scales),
+                      lhs_indices = array::unsafe_weak_copy(lhs_indices),
+                      rhs_indices = array::unsafe_weak_copy(rhs_indices),
+                      transpose_ = transpose_]() mutable {
+      mxfp4_bs_qmm_dispatch(
+          out, x, w, scales, lhs_indices, rhs_indices, transpose_);
+    });
+  }
 }

 template <typename T, typename U>
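The gather path above boils down to one independent qmm per index pair; here is a toy sketch of that control flow (all names, shapes, and the qmm callback are assumptions for illustration, not the mlx API):

#include <cstdint>

// out gets I blocks of M * N; lhs_indices picks the input block, rhs_indices
// the expert's packed weights and scales, mirroring mxfp4_bs_qmm_dispatch_typed.
void gather_qmm_sketch(
    float* out,
    const float* x,          // B blocks of M * K
    const uint32_t* w,       // E blocks of w_els packed words
    const uint8_t* scales,   // E blocks of g_els scale bytes
    const uint32_t* lhs_indices,
    const uint32_t* rhs_indices,
    int I, int M, int N, int K, int w_els, int g_els,
    void (*qmm)(float*, const float*, const uint32_t*, const uint8_t*, int, int, int)) {
  for (int i = 0; i < I; i++) {
    qmm(out + i * M * N,
        x + lhs_indices[i] * M * K,
        w + rhs_indices[i] * w_els,
        scales + rhs_indices[i] * g_els,
        M, N, K);
  }
}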
@ -95,10 +95,10 @@ inline U qdot(
|
|||||||
const device uint16_t* ws = (const device uint16_t*)w;
|
const device uint16_t* ws = (const device uint16_t*)w;
|
||||||
for (int i = 0; i < (values_per_thread / 4); i++) {
|
for (int i = 0; i < (values_per_thread / 4); i++) {
|
||||||
accum +=
|
accum +=
|
||||||
(x_thread[4 * i] * lut[ws[i] & 0x000f] +
|
(x_thread[4 * i] * lut[ws[i] & 0xf] +
|
||||||
x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0x000f] +
|
x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
|
||||||
x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0x000f] +
|
x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
|
||||||
x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0x000f]);
|
x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
|
||||||
}
|
}
|
||||||
return scale * accum;
|
return scale * accum;
|
||||||
}
|
}
|
||||||
@ -115,10 +115,10 @@ inline U qdot_safe(
|
|||||||
const device uint16_t* ws = (const device uint16_t*)w;
|
const device uint16_t* ws = (const device uint16_t*)w;
|
||||||
for (int i = 0; i < (N / 4); i++) {
|
for (int i = 0; i < (N / 4); i++) {
|
||||||
accum +=
|
accum +=
|
||||||
(x_thread[4 * i] * lut[ws[i] & 0x000f] +
|
(x_thread[4 * i] * lut[ws[i] & 0xf] +
|
||||||
x_thread[4 * i + 1] * lut[(ws[i] & 0x00f0) >> 4] +
|
x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
|
||||||
x_thread[4 * i + 2] * lut[(ws[i] & 0x0f00) >> 8] +
|
x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
|
||||||
x_thread[4 * i + 3] * lut[(ws[i] & 0xf000) >> 12]);
|
x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
|
||||||
}
|
}
|
||||||
return scale * accum;
|
return scale * accum;
|
||||||
}
|
}
|
||||||
@ -131,8 +131,8 @@ inline void qouter(
|
|||||||
thread U* result,
|
thread U* result,
|
||||||
const threadgroup U* lut) {
|
const threadgroup U* lut) {
|
||||||
for (int i = 0; i < (values_per_thread / 2); i++) {
|
for (int i = 0; i < (values_per_thread / 2); i++) {
|
||||||
result[2 * i] += x * scale * lut[w[i] & 0x0f];
|
result[2 * i] += x * scale * lut[w[i] & 0xf];
|
||||||
result[2 * i + 1] += x * scale * lut[(w[i] & 0xf0) >> 4];
|
result[2 * i + 1] += x * scale * lut[(w[i] >> 4) & 0xf];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -143,8 +143,8 @@ inline void dequantize(
|
|||||||
threadgroup U* w_local,
|
threadgroup U* w_local,
|
||||||
const threadgroup U* lut) {
|
const threadgroup U* lut) {
|
||||||
for (int i = 0; i < (N / 2); i++) {
|
for (int i = 0; i < (N / 2); i++) {
|
||||||
w_local[2 * i] = scale * lut[w[i] & 0x0f];
|
w_local[2 * i] = scale * lut[w[i] & 0xf];
|
||||||
w_local[2 * i + 1] = scale * lut[(w[i] & 0xf0) >> 4];
|
w_local[2 * i + 1] = scale * lut[(w[i] >> 4) & 0xf];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
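The four kernel hunks above only normalize how a nibble is extracted; an exhaustive host-side check (plain C++, not Metal) that the new shift-then-mask form is bit-identical to the old mask-then-shift form:

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t w = 0; w <= 0xffff; w++) {
    assert(((w & 0x00f0) >> 4) == ((w >> 4) & 0xf));
    assert(((w & 0x0f00) >> 8) == ((w >> 8) & 0xf));
    assert(((w & 0xf000) >> 12) == ((w >> 12) & 0xf));
  }
  return 0;
}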

mlx/fast.cpp (44 lines changed)
@@ -762,50 +762,6 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
   return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
 }

-array pack_and_quantize(
-    array& packed_w,
-    const array& scales,
-    const array& biases,
-    int bits,
-    const Stream& s) {
-  int el_per_int = 32 / bits;
-  array zero(0, packed_w.dtype());
-  array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1
-  packed_w = astype(
-      clip(
-          round(divide(subtract(packed_w, biases, s), scales, s), s),
-          zero,
-          n_bins,
-          s),
-      uint32,
-      s);
-  if (is_power_of_2(bits)) {
-    array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
-    packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
-    packed_w = sum(
-        multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
-  } else {
-    // This is slow but we have fast GPU/CPU versions of this function so we
-    // shouldn't be here often.
-    packed_w = expand_dims(packed_w, /* axis= */ -1, s);
-    packed_w = bitwise_and(
-        right_shift(packed_w, arange(bits, uint32, s), s),
-        array({1}, uint32),
-        s);
-    auto new_shape = packed_w.shape();
-    new_shape[new_shape.size() - 2] = -1;
-    new_shape.back() = 32;
-    packed_w = reshape(packed_w, new_shape, s);
-    array shifts = arange(32, uint32, s);
-    packed_w =
-        sum(left_shift(packed_w, shifts, s),
-            /* axis= */ -1,
-            /* keepdims= */ false,
-            s);
-  }
-  return packed_w;
-}
-
 bool Quantize::is_equivalent(const Primitive& other) const {
   const Quantize& p_other = static_cast<const Quantize&>(other);
   return (