mirror of https://github.com/ml-explore/mlx.git
synced 2025-08-27 08:46:41 +08:00

commit 8da1c64fe9
parent 51449428dd

    cpu mxfp4

@@ -13,6 +13,35 @@ namespace mlx::core {

namespace {

const static float MXFP4_LUT[16] = {
    +0.0f,
    +0.5f,
    +1.0f,
    +1.5f,
    +2.0f,
    +3.0f,
    +4.0f,
    +6.0f,
    -0.0f,
    -0.5f,
    -1.0f,
    -1.5f,
    -2.0f,
    -3.0f,
    -4.0f,
    -6.0f};

template <typename T>
static inline T dequantize_scale(uint8_t s) {
  using FOrI = union {
    bfloat16_t f;
    uint16_t i;
  };
  FOrI out;
  out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
  return static_cast<T>(out.f);
}

inline constexpr short get_pack_factor(int bits, int wsize = 8) {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
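
For reference, the two helpers above implement the standard MXFP4 decode: each weight is a 4-bit E2M1 code looked up in the table, and each group of 32 weights shares one E8M0 scale byte whose value is effectively 2^(s - 127) (the bfloat16 bit-trick in dequantize_scale, with a special case for s == 0). Below is a minimal standalone sketch of that decode; kMxfp4Lut and decode_scale are illustrative names, and plain float stands in for bfloat16_t.

// Standalone sketch of the decode path: one packed byte holds two 4-bit E2M1
// codes (low nibble first) and the per-group scale byte is an E8M0 exponent.
#include <cmath>
#include <cstdint>
#include <cstdio>

static const float kMxfp4Lut[16] = {
    +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

static float decode_scale(uint8_t s) {
  return std::ldexp(1.0f, static_cast<int>(s) - 127); // 2^(s - 127)
}

int main() {
  uint8_t packed = 0x2b; // low nibble 0xb -> -1.5, high nibble 0x2 -> +1.0
  uint8_t scale = 128;   // 2^(128 - 127) = 2.0
  float first = decode_scale(scale) * kMxfp4Lut[packed & 0xf];
  float second = decode_scale(scale) * kMxfp4Lut[(packed >> 4) & 0xf];
  std::printf("%g %g\n", first, second); // prints -3 2
  return 0;
}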

@@ -407,50 +436,230 @@ void _qmm_dispatch(
  }
}

-// template <typename T>
-// void _qmm_mxfp4_dispatch_typed(
-//     array& out,
-//     const array& x,
-//     const array& w,
-//     const array& scales,
-//     bool transposed_w) {
-//   int K = x.shape(-1);
-//   int M = x.ndim() > 1 ? x.shape(-2) : 1;
-//   int N = out.shape(-1);
-//   int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
-//   int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
-//   int batch_size = x.size() / (K * M);
-//
-//   auto out_ptr = out.data<T>();
-//   auto x_ptr = x.data<T>();
-//   auto w_ptr = w.data<uint32_t>();
-//   auto scales_ptr = scales.data<T>();
-//   for (int i = 0; i < batch_size; i++) {
-//     _qmm_mxfp4_dispatch_typed<T>(
-//         out_ptr + i * M * N,
-//         x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
-//         w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
-//         scales_ptr + elem_to_loc(i * g_els, scales.shape(),
-//         scales.strides()), M, N, K, transposed_w);
-//   }
-// }
-//
-//
-// void _qmm_mxfp4_dispatch(
-//     array& out,
-//     const array& x,
-//     const array& w,
-//     const array& scales,
-//     bool transposed_w) {
-//   switch (x.dtype()) {
-//     case bfloat16:
-//       _qmm_mxfp4_dispatch_typed<bfloat16>(out, x, w, scales, transposed_w);
-//       break;
-//     default:
-//       throw std::invalid_argument(
-//           "[quantized_matmul] only bfloat is supported for mxfp4");
-//   }
-// }

template <typename T>
void mxfp4_qmm(
    T* result,
    const T* x,
    const uint32_t* w,
    const uint8_t* scales,
    int M,
    int N,
    int K) {
  constexpr int group_size = 32;
  constexpr int pack_factor = get_pack_factor(4, 8);
  constexpr int bytes_per_pack = get_bytes_per_pack(4);
  constexpr int packs_in_group = group_size / pack_factor;

  for (int m = 0; m < M; m++) {
    const uint8_t* w_local = (const uint8_t*)w;
    const uint8_t* scales_local = scales;

    std::fill(result, result + N, 0);

    for (int k = 0; k < K; k++) {
      T* result_local = result;
      T xi = *x++;

      for (int n = 0; n < N; n += group_size) {
        T scale = dequantize_scale<T>(*scales_local++);
        for (int ng = 0; ng < packs_in_group; ng++) {
          uint8_t wi = *w_local++;
          #pragma clang loop unroll(full)
          for (int p = 0; p < pack_factor; p++) {
            (*result_local++) +=
                xi * scale * static_cast<T>(MXFP4_LUT[wi & 0xf]);
            wi >>= 4;
          }
        }
      }
    }

    result += N;
  }
}
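
The non-transposed kernel above assumes a packed layout in which every row of W along the output dimension holds N/2 bytes (two 4-bit columns per byte, low nibble first) and N/32 E8M0 scale bytes. The sketch below, with the hypothetical helper dequantize_row, expands one such row into floats under that assumption; it is illustrative only and not an MLX API.

// Illustrative only: expand one row of a non-transposed MXFP4 weight matrix
// (N/2 packed bytes, N/32 E8M0 scale bytes) into N floats, following the
// same nibble and group order the kernel above consumes.
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<float> dequantize_row(const uint8_t* w_row,
                                  const uint8_t* scale_row,
                                  int N) {
  static const float lut[16] = {
      +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
      -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
  std::vector<float> out(N);
  for (int n = 0; n < N; ++n) {
    uint8_t byte = w_row[n / 2];
    uint8_t code = (n % 2 == 0) ? (byte & 0xf) : (byte >> 4);
    float scale = std::ldexp(1.0f, static_cast<int>(scale_row[n / 32]) - 127);
    out[n] = scale * lut[code];
  }
  return out;
}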

template <typename T>
void mxfp4_qmm_t(
    T* result,
    const T* x,
    const uint32_t* w,
    const uint8_t* scales,
    int M,
    int N,
    int K) {
  constexpr int group_size = 32;
  constexpr int pack_factor = get_pack_factor(4, 8);
  constexpr int bytes_per_pack = get_bytes_per_pack(4);
  constexpr int packs_in_group = group_size / pack_factor;

  for (int m = 0; m < M; m++) {
    const uint8_t* w_local = (const uint8_t*)w;
    const uint8_t* scales_local = scales;

    for (int n = 0; n < N; n++) {
      const T* x_local = x;
      T sum = 0;
      for (int k = 0; k < K; k += group_size) {
        T scale = dequantize_scale<T>(*scales_local++);

        T gsum = 0;
        for (int kw = 0; kw < packs_in_group; kw++) {
          uint8_t wi = *w_local++;
          #pragma clang loop unroll(full)
          for (int p = 0; p < pack_factor; p++) {
            gsum += (*x_local++) * static_cast<T>(MXFP4_LUT[wi & 0xf]);
            wi >>= 4;
          }
        }
        sum += scale * gsum;
      }
      *result = sum;
      result++;
    }

    x += K;
  }
}

template <int S>
simd::Simd<float, S> mxfp4_extract_bits_simd(const uint32_t* w) {
  if constexpr (S == 8) {
    constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
    auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
    auto wi = simd::Simd<uint32_t, S>(*w);
    wi = wi >> shifts;
    wi = wi & 0xf;
    simd::Simd<float, S> w_out;
    for (int i = 0; i < S; ++i) {
      w_out[i] = MXFP4_LUT[wi[i]];
    }
    return w_out;
  } else {
    // Appease compiler.. but should never get here
    throw std::runtime_error("Unsupported combination for simd qmm.");
  }
}
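
The helper above broadcasts a single uint32_t to all lanes, shifts each lane by a different multiple of four, masks to a nibble, and gathers through the lookup table. A scalar rendering of the same extraction, for reference (extract_word is an illustrative name, not part of the commit):

// Scalar equivalent of the S == 8 branch above: one 32-bit word carries eight
// 4-bit codes, least-significant nibble first, each indexing the lookup table.
#include <array>
#include <cstdint>

std::array<float, 8> extract_word(uint32_t w, const float (&lut)[16]) {
  std::array<float, 8> out{};
  for (int i = 0; i < 8; ++i) {
    out[i] = lut[(w >> (4 * i)) & 0xf];
  }
  return out;
}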

template <typename T>
void mxfp4_qmm_t_simd(
    T* result,
    const T* x,
    const uint32_t* w,
    const uint8_t* scales,
    int M,
    int N,
    int K) {
  constexpr int group_size = 32;
  constexpr int pack_factor = 32 / 4;
  constexpr int packs_in_group = group_size / pack_factor;
  constexpr int S = simd::max_size<T>;
  static_assert(
      S % pack_factor == 0, "SIMD size must be divisible by pack factor");
  constexpr int packs_per_simd = S / pack_factor;

  for (int m = 0; m < M; m++) {
    const uint32_t* w_local = w;
    const uint8_t* scales_local = scales;

    for (int n = 0; n < N; n++) {
      simd::Simd<float, S> acc(0);
      auto x_local = x;
      for (int k = 0; k < K; k += group_size) {
        T scale = dequantize_scale<T>(*scales_local++);

        simd::Simd<float, S> g_acc(0);
        for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
          // Extract bits
          auto wf = mxfp4_extract_bits_simd<S>(w_local);
          w_local += packs_per_simd;
          simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
          g_acc = g_acc + x_simd * wf;
          x_local += S;
        }
        acc = acc + scale * g_acc;
      }

      *result = T(simd::sum(acc));
      result++;
    }
    x += K;
  }
}
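
For concreteness, these are the compile-time sizes the SIMD kernel above resolves to, assuming the target exposes 8 float lanes (the only width mxfp4_extract_bits_simd currently handles); the K = 64 figures are only an example.

// Worked loop bounds for the SIMD path above, under the assumption that
// simd::max_size<float> == 8 on the target.
constexpr int kGroupSize = 32;       // values sharing one E8M0 scale
constexpr int kPackFactor = 32 / 4;  // 8 four-bit codes per uint32_t word
constexpr int kPacksInGroup = kGroupSize / kPackFactor;  // 4 words per group
constexpr int kSimdWidth = 8;                            // assumed lane count
constexpr int kPacksPerSimd = kSimdWidth / kPackFactor;  // 1 word per step
static_assert(kPacksInGroup % kPacksPerSimd == 0, "whole words per group");
// For K = 64 each output element runs K / kGroupSize = 2 groups, and each
// group does kPacksInGroup / kPacksPerSimd = 4 multiply-accumulates of
// width 8 before the final horizontal sum.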

template <typename T>
void mxfp4_qmm_dispatch_transpose(
    T* result,
    const T* x,
    const uint32_t* w,
    const uint8_t* scales,
    int M,
    int N,
    int K,
    bool transposed_w) {
  if (transposed_w) {
    // the simd size must be a multiple of the number of elements per word
    if constexpr (simd::max_size<T> % 8 == 0) {
      mxfp4_qmm_t_simd<T>(result, x, w, scales, M, N, K);
    } else {
      mxfp4_qmm_t<T>(result, x, w, scales, M, N, K);
    }
  } else {
    mxfp4_qmm<T>(result, x, w, scales, M, N, K);
  }
}

template <typename T>
void mxfp4_qmm_dispatch_typed(
    array& out,
    const array& x,
    const array& w,
    const array& scales,
    bool transposed_w) {
  int K = x.shape(-1);
  int M = x.ndim() > 1 ? x.shape(-2) : 1;
  int N = out.shape(-1);
  int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
  int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
  int batch_size = x.size() / (K * M);

  auto out_ptr = out.data<T>();
  auto x_ptr = x.data<T>();
  auto w_ptr = w.data<uint32_t>();
  auto scales_ptr = scales.data<uint8_t>();
  for (int i = 0; i < batch_size; i++) {
    mxfp4_qmm_dispatch_transpose<T>(
        out_ptr + i * M * N,
        x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
        w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
        scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
        M,
        N,
        K,
        transposed_w);
  }
}

void mxfp4_qmm_dispatch(
    array& out,
    const array& x,
    const array& w,
    const array& scales,
    bool transposed_w) {
  switch (x.dtype()) {
    case bfloat16:
      mxfp4_qmm_dispatch_typed<bfloat16_t>(out, x, w, scales, transposed_w);
      break;
    case float16:
      mxfp4_qmm_dispatch_typed<float16_t>(out, x, w, scales, transposed_w);
      break;
    case float32:
      mxfp4_qmm_dispatch_typed<float>(out, x, w, scales, transposed_w);
      break;
    default:
      throw std::invalid_argument(
          "[quantized_matmul] only floating types are supported");
  }
}

template <typename T>
void _bs_qmm_dispatch_typed(

@@ -558,6 +767,74 @@ void _bs_qmm_dispatch(
  }
}

template <typename T>
void mxfp4_bs_qmm_dispatch_typed(
    array& out,
    const array& x,
    const array& w,
    const array& scales,
    const array& lhs_indices,
    const array& rhs_indices,
    bool transposed_w) {
  int K = x.shape(-1);
  int M = x.shape(-2);
  int N = out.shape(-1);

  int w_els = w.shape(-1) * w.shape(-2);
  int g_els = scales.shape(-1) * scales.shape(-2);

  auto out_ptr = out.data<T>();
  auto x_ptr = x.data<T>();
  auto w_ptr = w.data<uint32_t>();
  auto scales_ptr = scales.data<uint8_t>();
  auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
  auto rhs_indices_ptr = rhs_indices.data<uint32_t>();

  for (int i = 0; i < lhs_indices.size(); i++) {
    int x_idx = lhs_indices_ptr[elem_to_loc(
        i, lhs_indices.shape(), lhs_indices.strides())];
    int w_idx = rhs_indices_ptr[elem_to_loc(
        i, rhs_indices.shape(), rhs_indices.strides())];
    mxfp4_qmm_dispatch_transpose<T>(
        out_ptr + i * M * N,
        x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
        w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
        scales_ptr +
            elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
        M,
        N,
        K,
        transposed_w);
  }
}
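
To make the gather arithmetic above concrete, the hypothetical example below picks row-contiguous shapes and index vectors and prints the flat offsets that the elem_to_loc calls reduce to in that case (3 input batches, 4 experts, transposed 4-bit packing with 8 values per uint32_t word and one scale byte per 32 values). All names and numbers here are made up for illustration.

// Hypothetical shapes: x is (3, M, K), packed w is (4, N, K/8) as uint32_t,
// scales is (4, N, K/32); output batch i pairs x[lhs_indices[i]] with the
// expert rhs_indices[i].
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int M = 2, N = 64, K = 128;
  constexpr int w_els = N * (K / 8);   // uint32_t words per expert
  constexpr int g_els = N * (K / 32);  // scale bytes per expert
  const uint32_t lhs_indices[] = {0, 0, 2}; // x batch feeding output i
  const uint32_t rhs_indices[] = {1, 3, 0}; // expert used for output i
  for (int i = 0; i < 3; ++i) {
    std::printf("i=%d: out offset %d, x offset %u, w offset %u, scales offset %u\n",
                i,
                i * M * N,
                lhs_indices[i] * M * K,
                rhs_indices[i] * w_els,
                rhs_indices[i] * g_els);
  }
  return 0;
}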

void mxfp4_bs_qmm_dispatch(
    array& out,
    const array& x,
    const array& w,
    const array& scales,
    const array& lhs_indices,
    const array& rhs_indices,
    bool transposed_w) {
  switch (x.dtype()) {
    case float32:
      mxfp4_bs_qmm_dispatch_typed<float>(
          out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
      break;
    case float16:
      mxfp4_bs_qmm_dispatch_typed<float16_t>(
          out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
      break;
    case bfloat16:
      mxfp4_bs_qmm_dispatch_typed<bfloat16_t>(
          out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
      break;
    default:
      throw std::invalid_argument(
          "[quantized_matmul] only floating types are supported");
  }
}

} // namespace

void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {

@@ -604,15 +881,13 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
      _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
    });
  } else {
-    // encoder.dispatch([out = array::unsafe_weak_copy(out),
-    //                   x = array::unsafe_weak_copy(x),
-    //                   w = array::unsafe_weak_copy(w),
-    //                   scales = array::unsafe_weak_copy(scales),
-    //                   group_size_ = group_size_,
-    //                   bits_ = bits_,
-    //                   transpose_ = transpose_]() mutable {
-    //   _qmm_mxfp4_dispatch(out, x, w, scales, transpose_);
-    // });
    encoder.dispatch([out = array::unsafe_weak_copy(out),
                      x = array::unsafe_weak_copy(x),
                      w = array::unsafe_weak_copy(w),
                      scales = array::unsafe_weak_copy(scales),
                      transpose_ = transpose_]() mutable {
      mxfp4_qmm_dispatch(out, x, w, scales, transpose_);
    });
  }
}

@@ -622,9 +897,8 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto& x_pre = inputs[0];
  auto& w_pre = inputs[1];
  auto& scales_pre = inputs[2];
-  auto& biases_pre = inputs[3];
-  auto& lhs_indices = inputs[4];
-  auto& rhs_indices = inputs[5];
  auto& lhs_indices = inputs[inputs.size() - 2];
  auto& rhs_indices = inputs[inputs.size() - 1];

  std::vector<array> temps;
  auto ensure_row_contiguous_last_dims = [s = stream(),

@@ -643,7 +917,6 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto x = ensure_row_contiguous_last_dims(x_pre);
  auto w = ensure_row_contiguous_last_dims(w_pre);
  auto scales = ensure_row_contiguous_last_dims(scales_pre);
-  auto biases = ensure_row_contiguous_last_dims(biases_pre);

  out.set_data(allocator::malloc(out.nbytes()));

@@ -652,32 +925,46 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
  encoder.set_input_array(x);
  encoder.set_input_array(w);
  encoder.set_input_array(scales);
-  encoder.set_input_array(biases);
  encoder.set_input_array(lhs_indices);
  encoder.set_input_array(rhs_indices);
  encoder.set_output_array(out);
-  encoder.dispatch([out = array::unsafe_weak_copy(out),
-                    x = array::unsafe_weak_copy(x),
-                    w = array::unsafe_weak_copy(w),
-                    scales = array::unsafe_weak_copy(scales),
-                    biases = array::unsafe_weak_copy(biases),
-                    lhs_indices = array::unsafe_weak_copy(lhs_indices),
-                    rhs_indices = array::unsafe_weak_copy(rhs_indices),
-                    group_size_ = group_size_,
-                    bits_ = bits_,
-                    transpose_ = transpose_]() mutable {
-    _bs_qmm_dispatch(
-        out,
-        x,
-        w,
-        scales,
-        biases,
-        lhs_indices,
-        rhs_indices,
-        group_size_,
-        bits_,
-        transpose_);
-  });
  if (mode_ == "affine") {
    auto biases = ensure_row_contiguous_last_dims(inputs[3]);
    encoder.set_input_array(biases);
    encoder.dispatch([out = array::unsafe_weak_copy(out),
                      x = array::unsafe_weak_copy(x),
                      w = array::unsafe_weak_copy(w),
                      scales = array::unsafe_weak_copy(scales),
                      biases = array::unsafe_weak_copy(biases),
                      lhs_indices = array::unsafe_weak_copy(lhs_indices),
                      rhs_indices = array::unsafe_weak_copy(rhs_indices),
                      group_size_ = group_size_,
                      bits_ = bits_,
                      transpose_ = transpose_]() mutable {
      _bs_qmm_dispatch(
          out,
          x,
          w,
          scales,
          biases,
          lhs_indices,
          rhs_indices,
          group_size_,
          bits_,
          transpose_);
    });
  } else {
    encoder.dispatch([out = array::unsafe_weak_copy(out),
                      x = array::unsafe_weak_copy(x),
                      w = array::unsafe_weak_copy(w),
                      scales = array::unsafe_weak_copy(scales),
                      lhs_indices = array::unsafe_weak_copy(lhs_indices),
                      rhs_indices = array::unsafe_weak_copy(rhs_indices),
                      transpose_ = transpose_]() mutable {
      mxfp4_bs_qmm_dispatch(
          out, x, w, scales, lhs_indices, rhs_indices, transpose_);
    });
  }
}

template <typename T, typename U>

@@ -95,10 +95,10 @@ inline U qdot(
  const device uint16_t* ws = (const device uint16_t*)w;
  for (int i = 0; i < (values_per_thread / 4); i++) {
    accum +=
-        (x_thread[4 * i] * lut[ws[i] & 0x000f] +
-         x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0x000f] +
-         x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0x000f] +
-         x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0x000f]);
        (x_thread[4 * i] * lut[ws[i] & 0xf] +
         x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
         x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
         x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
  }
  return scale * accum;
}

@@ -115,10 +115,10 @@ inline U qdot_safe(
  const device uint16_t* ws = (const device uint16_t*)w;
  for (int i = 0; i < (N / 4); i++) {
    accum +=
-        (x_thread[4 * i] * lut[ws[i] & 0x000f] +
-         x_thread[4 * i + 1] * lut[(ws[i] & 0x00f0) >> 4] +
-         x_thread[4 * i + 2] * lut[(ws[i] & 0x0f00) >> 8] +
-         x_thread[4 * i + 3] * lut[(ws[i] & 0xf000) >> 12]);
        (x_thread[4 * i] * lut[ws[i] & 0xf] +
         x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] +
         x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] +
         x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]);
  }
  return scale * accum;
}

@@ -131,8 +131,8 @@ inline void qouter(
    thread U* result,
    const threadgroup U* lut) {
  for (int i = 0; i < (values_per_thread / 2); i++) {
-    result[2 * i] += x * scale * lut[w[i] & 0x0f];
-    result[2 * i + 1] += x * scale * lut[(w[i] & 0xf0) >> 4];
    result[2 * i] += x * scale * lut[w[i] & 0xf];
    result[2 * i + 1] += x * scale * lut[(w[i] >> 4) & 0xf];
  }
}

@@ -143,8 +143,8 @@ inline void dequantize(
    threadgroup U* w_local,
    const threadgroup U* lut) {
  for (int i = 0; i < (N / 2); i++) {
-    w_local[2 * i] = scale * lut[w[i] & 0x0f];
-    w_local[2 * i + 1] = scale * lut[(w[i] & 0xf0) >> 4];
    w_local[2 * i] = scale * lut[w[i] & 0xf];
    w_local[2 * i + 1] = scale * lut[(w[i] >> 4) & 0xf];
  }
}
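
The edits in these four kernels are cosmetic: for unsigned operands, masking with a wider literal, or masking before shifting, extracts the same nibble as the shorter shift-then-mask form. A standalone check of both identities:

// Shift-then-mask and mask-then-shift agree on the high nibble, and a wide
// mask equals the short one on the low nibble.
#include <cstdint>

static_assert(((0xA7u & 0xf0u) >> 4) == ((0xA7u >> 4) & 0xfu), "high nibble");
static_assert((0x1234u & 0x000fu) == (0x1234u & 0xfu), "low nibble");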

mlx/fast.cpp (44 changed lines)

@@ -762,50 +762,6 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
  return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
}

-array pack_and_quantize(
-    array& packed_w,
-    const array& scales,
-    const array& biases,
-    int bits,
-    const Stream& s) {
-  int el_per_int = 32 / bits;
-  array zero(0, packed_w.dtype());
-  array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1
-  packed_w = astype(
-      clip(
-          round(divide(subtract(packed_w, biases, s), scales, s), s),
-          zero,
-          n_bins,
-          s),
-      uint32,
-      s);
-  if (is_power_of_2(bits)) {
-    array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
-    packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
-    packed_w = sum(
-        multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
-  } else {
-    // This is slow but we have fast GPU/CPU versions of this function so we
-    // shouldn't be here often.
-    packed_w = expand_dims(packed_w, /* axis= */ -1, s);
-    packed_w = bitwise_and(
-        right_shift(packed_w, arange(bits, uint32, s), s),
-        array({1}, uint32),
-        s);
-    auto new_shape = packed_w.shape();
-    new_shape[new_shape.size() - 2] = -1;
-    new_shape.back() = 32;
-    packed_w = reshape(packed_w, new_shape, s);
-    array shifts = arange(32, uint32, s);
-    packed_w =
-        sum(left_shift(packed_w, shifts, s),
-            /* axis= */ -1,
-            /* keepdims= */ false,
-            s);
-  }
-  return packed_w;
-}
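
For reference, the power-of-two branch of the removed function packs el_per_int = 32 / bits clipped values into one 32-bit word by weighting value i with 2^(bits * i) and summing. A plain C++ rendering of that packing for bits = 4 (pack8_4bit is an illustrative name, not MLX code; the array-based original also clips values to [0, 2^bits - 1] first):

// Eight 4-bit quantized values packed into one uint32_t, lowest index in the
// least-significant nibble, matching sum(q[i] * 2^(4 * i)).
#include <cstdint>

uint32_t pack8_4bit(const uint32_t (&q)[8]) {
  uint32_t packed = 0;
  for (int i = 0; i < 8; ++i) {
    packed |= (q[i] & 0xf) << (4 * i);
  }
  return packed;
}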

bool Quantize::is_equivalent(const Primitive& other) const {
  const Quantize& p_other = static_cast<const Quantize&>(other);
  return (