Start to cleanup/unify accelerate and common back-ends (Part 1/N) (#1777)

* start to cleanup/unify accelerate and common back-ends

* more progress

* simplify

* add half type and allow infs in simd exp

* unify softmax + quantized, more dispatches to simd quantized mm

* add sin/cos, use simd in vector-scalar ops

* faster CPU vectorize quant

* faster erf/erfinv
This commit is contained in:
Awni Hannun
2025-01-29 14:34:49 -08:00
committed by GitHub
parent 7064fed1b1
commit 4758c8baa1
47 changed files with 1920 additions and 2640 deletions

View File

@@ -3,7 +3,7 @@
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/backend/common/simd/simd.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -151,6 +151,78 @@ void _qmm_t(
}
}
template <int bits, int S>
simd::Simd<uint32_t, S> extract_bits_simd(const uint32_t* w) {
constexpr int bitmask = (1 << bits) - 1;
simd::Simd<uint32_t, S> wi;
if constexpr (bits == 4 && S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
wi = simd::Simd<uint32_t, S>(*w);
wi = wi >> shifts;
wi = wi & bitmask;
} else if constexpr (bits == 8 && S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 8, 16, 24, 0, 8, 16, 24}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
auto l = simd::Simd<uint32_t, 4>(*w++);
auto r = simd::Simd<uint32_t, 4>(*w);
wi = simd::Simd<uint32_t, S>(l, r);
wi = wi >> shifts;
wi = wi & bitmask;
} else {
// Appease compiler.. but should never get here
throw std::runtime_error("Unsupported combination for simd qmm.");
}
return wi;
}
template <typename T, int bits, int group_size>
void _qmm_t_simd(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K) {
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
constexpr int S = simd::max_size<T>;
static_assert(
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
constexpr int packs_per_simd = S / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const T* scales_local = scales;
const T* biases_local = biases;
for (int n = 0; n < N; n++) {
simd::Simd<float, S> acc(0);
auto x_local = x;
for (int k = 0; k < K; k += group_size) {
T scale = *scales_local++;
T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
auto wf = simd::Simd<float, S>(extract_bits_simd<bits, S>(w_local));
w_local += packs_per_simd;
wf = wf * scale;
wf = wf + bias;
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
acc = acc + x_simd * wf;
x_local += S;
}
}
*result = T(simd::sum(acc));
result++;
}
x += K;
}
}
template <typename T, int bits, int group_size>
void _qmm_dispatch_transpose(
T* result,
@@ -163,9 +235,14 @@ void _qmm_dispatch_transpose(
int K,
bool transposed_w) {
if (transposed_w) {
return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
// the simd size must be a multiple of the number of elements per word
if constexpr (32 % bits == 0 && simd::max_size<T> % (32 / bits) == 0) {
_qmm_t_simd<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
} else {
_qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
}
} else {
return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
_qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
}
}
@@ -249,13 +326,13 @@ void _qmm_dispatch(
int group_size,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / x.shape(-1) / x.shape(-2);
int batch_size = x.size() / (K * M);
for (int i = 0; i < batch_size; i++) {
switch (x.dtype()) {
case float32:
@@ -384,7 +461,7 @@ void _bs_qmm_dispatch(
} // namespace
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x_pre = inputs[0];
@@ -411,7 +488,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}
void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0];