mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Remove accelerate/ (#1816)
* remove accelerate * comments * neon reduction
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
#include <limits>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -67,55 +68,121 @@ const complex64_t Limits<complex64_t>::min =
|
||||
|
||||
struct AndReduce {
|
||||
template <typename T>
|
||||
void operator()(bool* a, T b) {
|
||||
(*a) &= (b != 0);
|
||||
bool operator()(bool x, T y) {
|
||||
return x & (y != 0);
|
||||
}
|
||||
|
||||
void operator()(bool* y, bool x) {
|
||||
(*y) &= x;
|
||||
bool operator()(bool x, bool y) {
|
||||
return x & y;
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
|
||||
return x & (y != 0);
|
||||
};
|
||||
|
||||
template <int N>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
|
||||
return x & y;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
bool operator()(simd::Simd<T, N> x) {
|
||||
return simd::all(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct OrReduce {
|
||||
template <typename T>
|
||||
void operator()(bool* a, T b) {
|
||||
(*a) |= (b != 0);
|
||||
bool operator()(bool x, T y) {
|
||||
return x | (y != 0);
|
||||
}
|
||||
|
||||
void operator()(bool* y, bool x) {
|
||||
(*y) |= x;
|
||||
bool operator()(bool x, bool y) {
|
||||
return x | y;
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
|
||||
return x | (y != 0);
|
||||
};
|
||||
|
||||
template <int N>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
|
||||
return x | y;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
bool operator()(simd::Simd<T, N> x) {
|
||||
return simd::any(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct MaxReduce {
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
(*y) = (*y > x) ? *y : x;
|
||||
T operator()(T y, T x) {
|
||||
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
if (std::isnan(x)) {
|
||||
*y = x;
|
||||
} else {
|
||||
(*y) = (*y > x) ? *y : x;
|
||||
}
|
||||
template <int N, typename T>
|
||||
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
|
||||
return simd::maximum(x, y);
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::max(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct MinReduce {
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
(*y) = (*y < x) ? *y : x;
|
||||
T operator()(T y, T x) {
|
||||
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
if (std::isnan(x)) {
|
||||
*y = x;
|
||||
} else {
|
||||
(*y) = (*y < x) ? *y : x;
|
||||
}
|
||||
template <int N, typename T>
|
||||
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
|
||||
return simd::minimum(x, y);
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::min(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct SumReduce {
|
||||
template <typename T, typename U>
|
||||
U operator()(U y, T x) {
|
||||
return x + y;
|
||||
};
|
||||
|
||||
template <int N, typename T, typename U>
|
||||
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
|
||||
return y + x;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::sum(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ProdReduce {
|
||||
template <typename T, typename U>
|
||||
U operator()(U y, T x) {
|
||||
return x * y;
|
||||
};
|
||||
|
||||
template <int N, typename T, typename U>
|
||||
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
|
||||
return x * y;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::prod(x);
|
||||
};
|
||||
};
|
||||
|
||||
@@ -139,18 +206,16 @@ void reduce_dispatch_sum_prod(
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Sum) {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 0, op);
|
||||
reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
|
||||
}
|
||||
} else {
|
||||
auto op = [](auto y, auto x) { (*y) *= x; };
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t>(in, out, axes, 1, op);
|
||||
reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 1, op);
|
||||
reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -195,7 +260,7 @@ void nd_loop(
|
||||
loop_inner(0, 0);
|
||||
}
|
||||
|
||||
void Reduce::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
switch (reduce_type_) {
|
||||
|
||||
Reference in New Issue
Block a user