Remove accelerate/ (#1816)

* remove accelerate

* comments

* neon reduction
This commit is contained in:
Awni Hannun
2025-02-01 07:18:26 -08:00
committed by GitHub
parent f5cc1eea72
commit 80c863b972
12 changed files with 311 additions and 451 deletions

View File

@@ -5,6 +5,7 @@
#include <limits>
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/common/simd/simd.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -67,55 +68,121 @@ const complex64_t Limits<complex64_t>::min =
// Logical-AND reduction functor. Provides scalar steps, lane-wise Simd steps,
// and a final horizontal collapse of a Simd value via simd::all.
// NOTE(review): this span looks like a diff rendering with the +/- markers
// stripped; the pointer-based `void operator()(bool*, ...)` lines appear to be
// the pre-change implementation interleaved with the value-returning
// replacement. They are left untouched here — confirm which lines are live.
struct AndReduce {
template <typename T>
void operator()(bool* a, T b) {
(*a) &= (b != 0);
// Scalar step: AND the accumulator `x` with "`y` is non-zero".
bool operator()(bool x, T y) {
return x & (y != 0);
}
void operator()(bool* y, bool x) {
(*y) &= x;
// Scalar step for inputs that are already bool.
bool operator()(bool x, bool y) {
return x & y;
}
// Lane-wise step: `y` is the bool accumulator, `x` holds the incoming values.
// The non-zero test must be applied to `x` (the T-typed operand), mirroring
// the scalar overload above; testing `y` and bitwise-ANDing the raw `x` lanes
// gives a wrong result for values such as 2 (2 & 1 == 0).
template <int N, typename T>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
return y & (x != 0);
};
// Lane-wise step when both operands are already bool lanes.
template <int N>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
return x & y;
};
// Horizontal step: collapses a Simd value to one bool with simd::all.
template <int N, typename T>
bool operator()(simd::Simd<T, N> x) {
return simd::all(x);
};
};
// Logical-OR reduction functor. Provides scalar steps, lane-wise Simd steps,
// and a final horizontal collapse of a Simd value via simd::any.
// NOTE(review): this span looks like a diff rendering with the +/- markers
// stripped; the pointer-based `void operator()(bool*, ...)` lines appear to be
// the pre-change implementation interleaved with the value-returning
// replacement. They are left untouched here — confirm which lines are live.
struct OrReduce {
template <typename T>
void operator()(bool* a, T b) {
(*a) |= (b != 0);
// Scalar step: OR the accumulator `x` with "`y` is non-zero".
bool operator()(bool x, T y) {
return x | (y != 0);
}
void operator()(bool* y, bool x) {
(*y) |= x;
// Scalar step for inputs that are already bool.
bool operator()(bool x, bool y) {
return x | y;
}
// Lane-wise step: `y` is the bool accumulator, `x` holds the incoming values.
// The non-zero test is applied to `x` (the T-typed operand) so this overload
// agrees with the scalar one above instead of OR-ing raw T lanes with a mask.
template <int N, typename T>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
return y | (x != 0);
};
// Lane-wise step when both operands are already bool lanes.
template <int N>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
return x | y;
};
// Horizontal step: collapses a Simd value to one bool with simd::any.
template <int N, typename T>
bool operator()(simd::Simd<T, N> x) {
return simd::any(x);
};
};
// Maximum reduction functor: a scalar step, a lane-wise Simd step, and a
// horizontal collapse of a Simd value via simd::max.
// NOTE(review): this span looks like a diff rendering with the +/- markers
// stripped; the `enable_if_t` pointer-based overloads appear to be the
// pre-change implementation interleaved with the replacement — confirm which
// lines are live.
struct MaxReduce {
template <typename T>
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
(*y) = (*y > x) ? *y : x;
// Scalar step: wraps both operands in width-1 Simd values, forwards to the
// lane-wise overload below, and unwraps the result through `.value`.
T operator()(T y, T x) {
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
};
template <typename T>
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
// Pointer-based non-integral path: a NaN input overwrites the accumulator.
if (std::isnan(x)) {
*y = x;
} else {
(*y) = (*y > x) ? *y : x;
}
// Lane-wise step: delegates to simd::maximum.
// NOTE(review): NaN handling here depends on simd::maximum, which is not
// visible in this file — confirm it matches the isnan branch above.
template <int N, typename T>
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
return simd::maximum(x, y);
};
// Horizontal step: reduces a Simd value to a single T with simd::max.
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::max(x);
};
};
// Minimum reduction functor: a scalar step, a lane-wise Simd step, and a
// horizontal collapse of a Simd value via simd::min.
// NOTE(review): this span looks like a diff rendering with the +/- markers
// stripped; the `enable_if_t` pointer-based overloads appear to be the
// pre-change implementation interleaved with the replacement — confirm which
// lines are live.
struct MinReduce {
template <typename T>
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
(*y) = (*y < x) ? *y : x;
// Scalar step: wraps both operands in width-1 Simd values, forwards to the
// lane-wise overload below, and unwraps the result through `.value`.
T operator()(T y, T x) {
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
};
template <typename T>
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
// Pointer-based non-integral path: a NaN input overwrites the accumulator.
if (std::isnan(x)) {
*y = x;
} else {
(*y) = (*y < x) ? *y : x;
}
// Lane-wise step: delegates to simd::minimum.
// NOTE(review): NaN handling here depends on simd::minimum, which is not
// visible in this file — confirm it matches the isnan branch above.
template <int N, typename T>
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
return simd::minimum(x, y);
};
// Horizontal step: reduces a Simd value to a single T with simd::min.
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::min(x);
};
};
struct SumReduce {
template <typename T, typename U>
U operator()(U y, T x) {
return x + y;
};
template <int N, typename T, typename U>
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
return y + x;
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::sum(x);
};
};
struct ProdReduce {
template <typename T, typename U>
U operator()(U y, T x) {
return x * y;
};
template <int N, typename T, typename U>
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
return x * y;
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::prod(x);
};
};
@@ -139,18 +206,16 @@ void reduce_dispatch_sum_prod(
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Sum) {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 0, op);
reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
} else {
reduction_op<InT, InT>(in, out, axes, 0, op);
reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
}
} else {
auto op = [](auto y, auto x) { (*y) *= x; };
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 1, op);
reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
} else {
reduction_op<InT, InT>(in, out, axes, 1, op);
reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
}
}
}
@@ -195,7 +260,7 @@ void nd_loop(
loop_inner(0, 0);
}
void Reduce::eval(const std::vector<array>& inputs, array& out) {
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (reduce_type_) {