Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Start to cleanup/unify accelerate and common back-ends (Part 1/N) (#1777)

* start to cleanup/unify accelerate and common back-ends
* more progress
* simplify
* add half type and allow infs in simd exp
* unify softmax + quantized, more dispatches to simd quantized mm
* add sin/cos, use simd in vector-scalar ops
* faster CPU vectorize quant
* faster erf/erfinv
@@ -3,7 +3,7 @@
 #include <cassert>

 #include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/ops.h"
+#include "mlx/backend/common/simd/simd.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
@@ -151,6 +151,78 @@ void _qmm_t(
   }
 }

+template <int bits, int S>
+simd::Simd<uint32_t, S> extract_bits_simd(const uint32_t* w) {
+  constexpr int bitmask = (1 << bits) - 1;
+  simd::Simd<uint32_t, S> wi;
+  if constexpr (bits == 4 && S == 8) {
+    constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
+    auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
+    wi = simd::Simd<uint32_t, S>(*w);
+    wi = wi >> shifts;
+    wi = wi & bitmask;
+  } else if constexpr (bits == 8 && S == 8) {
+    constexpr std::array<uint32_t, 8> shifts_ = {{0, 8, 16, 24, 0, 8, 16, 24}};
+    auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
+    auto l = simd::Simd<uint32_t, 4>(*w++);
+    auto r = simd::Simd<uint32_t, 4>(*w);
+    wi = simd::Simd<uint32_t, S>(l, r);
+    wi = wi >> shifts;
+    wi = wi & bitmask;
+  } else {
+    // Appease compiler.. but should never get here
+    throw std::runtime_error("Unsupported combination for simd qmm.");
+  }
+  return wi;
+}
+
+template <typename T, int bits, int group_size>
+void _qmm_t_simd(
+    T* result,
+    const T* x,
+    const uint32_t* w,
+    const T* scales,
+    const T* biases,
+    int M,
+    int N,
+    int K) {
+  constexpr int pack_factor = 32 / bits;
+  constexpr int packs_in_group = group_size / pack_factor;
+  constexpr int S = simd::max_size<T>;
+  static_assert(
+      S % pack_factor == 0, "SIMD size must be divisible by pack factor");
+  constexpr int packs_per_simd = S / pack_factor;
+
+  for (int m = 0; m < M; m++) {
+    const uint32_t* w_local = w;
+    const T* scales_local = scales;
+    const T* biases_local = biases;
+
+    for (int n = 0; n < N; n++) {
+      simd::Simd<float, S> acc(0);
+      auto x_local = x;
+      for (int k = 0; k < K; k += group_size) {
+        T scale = *scales_local++;
+        T bias = *biases_local++;
+
+        for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
+          auto wf = simd::Simd<float, S>(extract_bits_simd<bits, S>(w_local));
+          w_local += packs_per_simd;
+          wf = wf * scale;
+          wf = wf + bias;
+          simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
+          acc = acc + x_simd * wf;
+          x_local += S;
+        }
+      }
+
+      *result = T(simd::sum(acc));
+      result++;
+    }
+    x += K;
+  }
+}
+
 template <typename T, int bits, int group_size>
 void _qmm_dispatch_transpose(
     T* result,
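Editorial note: the two new templates above unpack 32/bits quantized values per uint32 word with vectorized shifts and masks, dequantize them as scale * q + bias per group of group_size values, and accumulate the dot product with x in float. The scalar sketch below is not part of the commit; qmm_t_ref and its layout conventions are assumptions taken from the diff, written one lane at a time as a mental model for what the SIMD path computes.

// Hypothetical scalar reference (not in the diff): same result as _qmm_t_simd,
// computed element by element. Each uint32 word packs 32/bits quantized
// values; every group_size values share one scale and bias, and a weight is
// dequantized as scale * q + bias before the dot product with x.
#include <cstdint>

template <typename T, int bits, int group_size>
void qmm_t_ref(
    T* result,
    const T* x,
    const uint32_t* w,
    const T* scales,
    const T* biases,
    int M,
    int N,
    int K) {
  constexpr int pack_factor = 32 / bits;
  constexpr uint32_t bitmask = (1u << bits) - 1;
  for (int m = 0; m < M; m++) {
    const uint32_t* w_local = w;
    const T* scales_local = scales;
    const T* biases_local = biases;
    for (int n = 0; n < N; n++) {
      float acc = 0.0f;
      const T* x_local = x;
      for (int k = 0; k < K; k += group_size) {
        float scale = static_cast<float>(*scales_local++);
        float bias = static_cast<float>(*biases_local++);
        for (int g = 0; g < group_size; g += pack_factor) {
          uint32_t word = *w_local++;
          for (int p = 0; p < pack_factor; p++) {
            // Unpack one bits-wide value and dequantize it.
            float q = static_cast<float>((word >> (p * bits)) & bitmask);
            acc += static_cast<float>(*x_local++) * (q * scale + bias);
          }
        }
      }
      *result++ = static_cast<T>(acc);
    }
    x += K;
  }
}

The SIMD version does the same arithmetic but consumes S = simd::max_size<T> activations per step, i.e. packs_per_simd packed words at a time, and reduces the accumulator with simd::sum at the end.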
@@ -163,9 +235,14 @@ void _qmm_dispatch_transpose(
     int K,
     bool transposed_w) {
   if (transposed_w) {
-    return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
+    // the simd size must be a multiple of the number of elements per word
+    if constexpr (32 % bits == 0 && simd::max_size<T> % (32 / bits) == 0) {
+      _qmm_t_simd<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
+    } else {
+      _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
+    }
   } else {
-    return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
+    _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
   }
 }

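Editorial note: the new if constexpr guard only takes the SIMD path when a 32-bit word holds a whole number of quantized values and a SIMD register holds a whole number of packed words; otherwise it falls back to the existing scalar _qmm_t. A minimal sketch of that check (illustrative only; simd_qmm_ok is not an MLX function, and the 8-lane width is just an assumed example value for simd::max_size<T>):

#include <cstdio>

// Mirrors the compile-time condition in _qmm_dispatch_transpose:
// 32 % bits == 0 and simd_width % (32 / bits) == 0.
constexpr bool simd_qmm_ok(int bits, int simd_width) {
  return 32 % bits == 0 && simd_width % (32 / bits) == 0;
}

int main() {
  // With an assumed 8-lane register, 4-bit and 8-bit weights qualify,
  // while 2-bit weights pack 16 values per word (8 % 16 != 0) and would
  // fall back to the scalar path.
  std::printf("4-bit, 8 lanes: %d\n", simd_qmm_ok(4, 8)); // 1
  std::printf("8-bit, 8 lanes: %d\n", simd_qmm_ok(8, 8)); // 1
  std::printf("2-bit, 8 lanes: %d\n", simd_qmm_ok(2, 8)); // 0
}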
@@ -249,13 +326,13 @@ void _qmm_dispatch(
     int group_size,
     bool transposed_w) {
   int K = x.shape(-1);
-  int M = x.shape(-2);
+  int M = x.ndim() > 1 ? x.shape(-2) : 1;
   int N = out.shape(-1);

   int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
   int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;

-  int batch_size = x.size() / x.shape(-1) / x.shape(-2);
+  int batch_size = x.size() / (K * M);
   for (int i = 0; i < batch_size; i++) {
     switch (x.dtype()) {
       case float32:
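Editorial note: the _qmm_dispatch change makes the row count robust to 1-D activations: M falls back to 1 when x has a single dimension, and batch_size is derived from K * M rather than from shape(-2), which does not exist for a vector input. A toy check of the arithmetic (not in the diff; the shapes are made up for illustration):

#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  auto count = [](const std::vector<int>& shape) {
    return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
  };

  // Batched input, e.g. x of shape (2, 16, 512): K = 512, M = 16.
  std::vector<int> batched = {2, 16, 512};
  int K = 512, M = 16;
  assert(count(batched) / (K * M) == 2); // batch_size = 2

  // 1-D input, e.g. x of shape (512): ndim == 1, so M falls back to 1.
  std::vector<int> vec = {512};
  M = 1;
  assert(count(vec) / (K * M) == 1); // a single row, single batch
  return 0;
}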
@@ -384,7 +461,7 @@ void _bs_qmm_dispatch(

 } // namespace

-void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
+void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 4);

   auto& x_pre = inputs[0];
@@ -411,7 +488,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
   _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
 }

-void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
+void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 6);

   auto& x_pre = inputs[0];