Start to cleanup/unify accelerate and common back-ends (Part 1/N) (#1777)

* start to cleanup/unify accelerate and common back-ends

* more progress

* simplify

* add half type and allow infs in simd exp

* unify softmax + quantized, more dispatches to simd quantized mm

* add sin/cos, use simd in vector-scalar ops

* faster CPU vectorize quant

* faster erf/erfinv
Author: Awni Hannun
Date: 2025-01-29 14:34:49 -08:00
Committed by: GitHub
Parent: 7064fed1b1
Commit: 4758c8baa1
47 changed files with 1920 additions and 2640 deletions
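
The recurring pattern in the diff below is what the commit message calls "use simd in vector-scalar ops" and "faster CPU vectorize quant": each element-wise functor now runs a SIMD main loop over packs of N elements and finishes with a scalar tail. A minimal standalone sketch of that pattern, with illustrative names (vector_vector_sketch, kLanes) standing in for mlx's simd::Simd<T, N>, simd::load, simd::store, and simd::max_size<T>:

```cpp
// Sketch only: the real functors below use mlx's simd wrappers; the inner
// element loop here is a portable stand-in for one SIMD load/op/store.
template <typename T, typename Op>
void vector_vector_sketch(const T* a, const T* b, T* dst, int size, Op op) {
  constexpr int kLanes = 8; // stand-in for simd::max_size<T>
  while (size >= kLanes) {
    for (int i = 0; i < kLanes; ++i) { // one simd::store(dst, op(load(a), load(b)))
      dst[i] = op(a[i], b[i]);
    }
    a += kLanes;
    b += kLanes;
    dst += kLanes;
    size -= kLanes;
  }
  while (size-- > 0) { // scalar tail for the remaining size % kLanes elements
    *dst++ = op(*a++, *b++);
  }
}
```

The VectorScalar, ScalarVector, and VectorVector hunks below apply exactly this structure, differing only in which operand is broadcast.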


@@ -7,6 +7,8 @@
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/common/simd/simd.h"
namespace mlx::core {
namespace {
@@ -122,16 +124,22 @@ void set_binary_op_output_data(
   }
 }
-struct UseDefaultBinaryOp {};
-template <typename T, typename U, typename Op>
-struct DefaultVectorScalar {
+template <typename Op>
+struct VectorScalar {
   Op op;
-  DefaultVectorScalar(Op op_) : op(op_) {}
+  VectorScalar(Op op_) : op(op_) {}
+  template <typename T, typename U>
   void operator()(const T* a, const T* b, U* dst, int size) {
     T scalar = *b;
+    constexpr int N = simd::max_size<T>;
+    while (size >= N) {
+      simd::store(dst, op(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
+      dst += N;
+      a += N;
+      size -= N;
+    }
     while (size-- > 0) {
       *dst = op(*a, scalar);
       dst++;
@@ -140,14 +148,22 @@ struct DefaultVectorScalar {
   }
 };
-template <typename T, typename U, typename Op>
-struct DefaultScalarVector {
+template <typename Op>
+struct ScalarVector {
   Op op;
-  DefaultScalarVector(Op op_) : op(op_) {}
+  ScalarVector(Op op_) : op(op_) {}
+  template <typename T, typename U>
   void operator()(const T* a, const T* b, U* dst, int size) {
     T scalar = *a;
+    constexpr int N = simd::max_size<T>;
+    while (size >= N) {
+      simd::store(dst, op(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
+      dst += N;
+      b += N;
+      size -= N;
+    }
     while (size-- > 0) {
       *dst = op(scalar, *b);
       dst++;
@@ -156,13 +172,22 @@ struct DefaultScalarVector {
   }
 };
-template <typename T, typename U, typename Op>
-struct DefaultVectorVector {
+template <typename Op>
+struct VectorVector {
   Op op;
-  DefaultVectorVector(Op op_) : op(op_) {}
+  VectorVector(Op op_) : op(op_) {}
+  template <typename T, typename U>
   void operator()(const T* a, const T* b, U* dst, int size) {
+    constexpr int N = simd::max_size<T>;
+    while (size >= N) {
+      simd::store(dst, op(simd::load<T, N>(a), simd::load<T, N>(b)));
+      dst += N;
+      a += N;
+      b += N;
+      size -= N;
+    }
     while (size-- > 0) {
       *dst = op(*a, *b);
       dst++;
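
All three functors above call `op` both on scalars and on simd::Simd<T, N> packs, so whatever op they are handed must be callable on either. A hedged sketch of the shape such an op presumably takes (AddLike is an illustrative name, not something defined in this commit):

```cpp
// Assumption: mlx's simd::Simd<T, N> provides the usual arithmetic operator
// overloads, so one templated operator() covers both the SIMD loop and the tail.
struct AddLike {
  template <typename T>
  T operator()(T x, T y) const {
    return x + y; // instantiated for float, int32_t, simd::Simd<float, N>, ...
  }
};
```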
@@ -277,21 +302,8 @@ void binary_op_dispatch_dims(
   }
 }
-template <
-    typename T,
-    typename U,
-    typename Op,
-    typename OpSV,
-    typename OpVS,
-    typename OpVV>
-void binary_op(
-    const array& a,
-    const array& b,
-    array& out,
-    Op op,
-    OpSV opsv,
-    OpVS opvs,
-    OpVV opvv) {
+template <typename T, typename U, typename Op>
+void binary_op(const array& a, const array& b, array& out, Op op) {
   auto bopt = get_binary_op_type(a, b);
   set_binary_op_output_data(a, b, out, bopt);
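
For context, what the signature change above means at a hypothetical call site (opsv/opvs/opvv are the parameters being deleted):

```cpp
// Before: callers threaded three extra functors (or UseDefaultBinaryOp
// placeholders) through every binary_op call:
//
//   binary_op<float, float>(a, b, out, op, opsv, opvs, opvv);
//
// After: the ScalarVector/VectorScalar/VectorVector fast paths are built
// from `op` inside binary_op, so the call collapses to:
//
//   binary_op<float, float>(a, b, out, op);
```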
@@ -303,19 +315,19 @@ void binary_op(
   // The full computation is scalar vector so delegate to the op
   if (bopt == BinaryOpType::ScalarVector) {
-    opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
+    ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
     return;
   }
   // The full computation is vector scalar so delegate to the op
   if (bopt == BinaryOpType::VectorScalar) {
-    opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
+    VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
     return;
   }
   // The full computation is vector vector so delegate to the op
   if (bopt == BinaryOpType::VectorVector) {
-    opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
+    VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
     return;
   }
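
A worked reading of the three contiguous fast paths above; the shape-to-path mapping is an assumption (get_binary_op_type is not shown in this hunk), but the loop sizes come straight from the calls:

```cpp
// Hypothetical shapes, chosen to hit each path:
//   a (1,)    op b (1024,) contiguous -> ScalarVector, loops b.data_size() = 1024
//   a (1024,) op b (1,)    contiguous -> VectorScalar, loops a.data_size() = 1024
//   a (1024,) op b (1024,) contiguous -> VectorVector, loops out.size()    = 1024
// Shapes that need real broadcasting fall through to the strided dispatch
// later in binary_op.
```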
@@ -376,15 +388,39 @@ void binary_op(
   switch (bopt) {
     case BinaryOpType::VectorVector:
       binary_op_dispatch_dims<T, U, true>(
-          a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
+          a,
+          b,
+          out,
+          VectorVector{op},
+          dim,
+          new_shape,
+          a_strides,
+          b_strides,
+          strides);
       break;
     case BinaryOpType::VectorScalar:
       binary_op_dispatch_dims<T, U, true>(
-          a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
+          a,
+          b,
+          out,
+          VectorScalar{op},
+          dim,
+          new_shape,
+          a_strides,
+          b_strides,
+          strides);
       break;
     case BinaryOpType::ScalarVector:
       binary_op_dispatch_dims<T, U, true>(
-          a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
+          a,
+          b,
+          out,
+          ScalarVector{op},
+          dim,
+          new_shape,
+          a_strides,
+          b_strides,
+          strides);
       break;
     default:
       binary_op_dispatch_dims<T, U, false>(
@@ -393,134 +429,52 @@ void binary_op(
   }
 }
-template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
-void binary_op(
-    const array& a,
-    const array& b,
-    array& out,
-    Op op,
-    OpSV opsv,
-    OpVS opvs,
-    OpVV opvv) {
-  // TODO: The following mess of constexpr evaluations can probably be achieved
-  // with template specializations and overloading. Would it be simpler?
-  if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
-    if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
-      if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
-        // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
-        binary_op<T, T>(
-            a,
-            b,
-            out,
-            op,
-            DefaultScalarVector<T, T, Op>(op),
-            DefaultVectorScalar<T, T, Op>(op),
-            DefaultVectorVector<T, T, Op>(op));
-      } else {
-        // opsv and opvs were UseDefaultBinaryOp
-        binary_op<T, T>(
-            a,
-            b,
-            out,
-            op,
-            DefaultScalarVector<T, T, Op>(op),
-            DefaultVectorScalar<T, T, Op>(op),
-            opvv);
-      }
-    } else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
-                             value) {
-      // opsv and opvv were UseDefaultBinaryOp
-      binary_op<T, T>(
-          a,
-          b,
-          out,
-          op,
-          DefaultScalarVector<T, T, Op>(op),
-          opvs,
-          DefaultVectorVector<T, T, Op>(op));
-    } else {
-      // opsv was UseDefaultBinaryOp
-      binary_op<T, T>(
-          a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
-    }
-  } else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
-                           value) {
-    if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
-      // opvs and opvv were UseDefaultBinaryOp
-      binary_op<T, T>(
-          a,
-          b,
-          out,
-          op,
-          opsv,
-          DefaultVectorScalar<T, T, Op>(op),
-          DefaultVectorVector<T, T, Op>(op));
-    } else {
-      // opvs was UseDefaultBinaryOp
-      binary_op<T, T>(
-          a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
-    }
-  } else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
-                           value) {
-    // opvv was UseDefaultBinaryOp
-    binary_op<T, T>(
-        a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
-  } else {
-    // All ops provided
-    binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
-  }
-}
 template <typename T, typename Op>
 void binary_op(const array& a, const array& b, array& out, Op op) {
-  DefaultScalarVector<T, T, Op> opsv(op);
-  DefaultVectorScalar<T, T, Op> opvs(op);
-  DefaultVectorVector<T, T, Op> opvv(op);
-  binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
+  binary_op<T, T>(a, b, out, op);
 }
-template <typename... Ops>
-void binary(const array& a, const array& b, array& out, Ops... ops) {
+template <typename Op>
+void binary(const array& a, const array& b, array& out, Op op) {
   switch (out.dtype()) {
     case bool_:
-      binary_op<bool>(a, b, out, ops...);
+      binary_op<bool>(a, b, out, op);
       break;
     case uint8:
-      binary_op<uint8_t>(a, b, out, ops...);
+      binary_op<uint8_t>(a, b, out, op);
       break;
     case uint16:
-      binary_op<uint16_t>(a, b, out, ops...);
+      binary_op<uint16_t>(a, b, out, op);
       break;
     case uint32:
-      binary_op<uint32_t>(a, b, out, ops...);
+      binary_op<uint32_t>(a, b, out, op);
       break;
     case uint64:
-      binary_op<uint64_t>(a, b, out, ops...);
+      binary_op<uint64_t>(a, b, out, op);
       break;
     case int8:
-      binary_op<int8_t>(a, b, out, ops...);
+      binary_op<int8_t>(a, b, out, op);
       break;
     case int16:
-      binary_op<int16_t>(a, b, out, ops...);
+      binary_op<int16_t>(a, b, out, op);
       break;
     case int32:
-      binary_op<int32_t>(a, b, out, ops...);
+      binary_op<int32_t>(a, b, out, op);
       break;
     case int64:
-      binary_op<int64_t>(a, b, out, ops...);
+      binary_op<int64_t>(a, b, out, op);
       break;
     case float16:
-      binary_op<float16_t>(a, b, out, ops...);
+      binary_op<float16_t>(a, b, out, op);
       break;
     case float32:
-      binary_op<float>(a, b, out, ops...);
+      binary_op<float>(a, b, out, op);
       break;
     case bfloat16:
-      binary_op<bfloat16_t>(a, b, out, ops...);
+      binary_op<bfloat16_t>(a, b, out, op);
       break;
     case complex64:
-      binary_op<complex64_t>(a, b, out, ops...);
+      binary_op<complex64_t>(a, b, out, op);
       break;
   }
 }
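
Finally, a hypothetical call site for the simplified entry point (not part of this diff). Because binary() switches over every dtype, the op must compile for all of them, from bool to complex64_t, and with the new functors it is also applied to simd packs, so a generic callable fits naturally:

```cpp
// Sketch only: `array` and `binary` are the ones from this header; the lambda
// stands in for whatever op struct a backend actually passes.
void add_kernel(const array& a, const array& b, array& out) {
  binary(a, b, out, [](auto x, auto y) { return x + y; });
}
```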