Remove code duplication in reduce ops (#793)

* Remove code duplication in reduce ops

* Remove the unnecessary lambda

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
nicolov 2024-03-11 18:57:07 +01:00 committed by GitHub
parent 7c441600fe
commit 0ae22b915b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -10,78 +10,65 @@
namespace mlx::core { namespace mlx::core {
template <typename T, typename VT, int N> namespace {
void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
VT val = (*(VT*)x);
*(VT*)a += val;
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a++ += *x++;
}
}
}
// TODO: Add proper templates for the strided reduce algorithm so we don't have template <typename T, typename VT>
// to write max/min/sum etc. struct MinReduction {
template <typename T, typename VT, int N> T operator()(const T& a, const T& b) {
void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) { return std::min(a, b);
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_max((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::max(*a, *x);
a++;
x++;
}
} }
}
template <typename T, typename VT, int N> VT operator()(VT a, VT b) {
void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) { return simd_min(a, b);
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_min((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::min(*a, *x);
a++;
x++;
}
} }
} };
template <typename T, typename VT, int N> template <typename T, typename VT>
void _vectorized_sum(const T* x, T* accum, int size) { struct MaxReduction {
VT _sum = {0}; T operator()(const T& a, const T& b) {
while (size >= N) { return std::max(a, b);
_sum += (*(VT*)x);
x += N;
size -= N;
} }
T sum = _sum[0];
for (int i = 1; i < N; i++) { VT operator()(VT a, VT b) {
sum += _sum[i]; return simd_max(a, b);
} }
*accum += sum; };
}
// Addition functor for the reduction kernels.
//
// Provides two call operators so a single object can combine either two
// scalars of type T or two vector values of type VT. Both overloads return
// `a + b`, relying on the operand type's own operator+. Intended for distinct
// T and VT; with identical types a call would be ambiguous between the two
// overloads.
//
// The functor holds no state, so both operators are const-qualified and may
// be invoked through const objects and const references.
template <typename T, typename VT>
struct SumReduction {
  // Scalar overload: returns a + b converted to T.
  T operator()(const T& a, const T& b) const {
    return a + b;
  }

  // Vector overload: returns a + b as computed by VT's operator+.
  VT operator()(VT a, VT b) const {
    return a + b;
  }
};
// Accumulates `size` consecutive rows of `stride` elements each into `accum`.
//
// Template parameters:
//   T         - scalar element type of the input and the accumulator.
//   VT        - vector type the bulk loop reads and writes N elements at a time.
//   N         - number of T elements consumed per VT access; must be positive.
//   Reduction - default-constructible functor callable as op(VT, VT) -> VT and
//               op(T, T) -> T.
//
// Call arguments:
//   x      - input, read as size * stride contiguous elements.
//   accum  - accumulator of `stride` elements, updated in place; element j
//            ends up combined with x[i * stride + j] for every row i.
//   size   - number of rows to fold into accum.
//   stride - number of elements per row.
//
// NOTE(review): the bulk loop dereferences `x` and `accum` as VT, so both
// pointers are presumably expected to satisfy VT's alignment - confirm
// against the callers.
template <typename T, typename VT, int N, typename Reduction>
struct StridedReduce {
  static_assert(N > 0, "StridedReduce requires a positive vector width");

  void operator()(const T* x, T* accum, int size, size_t stride) const {
    // Same unsigned type as `stride`, so the loop bounds need no implicit
    // signed/unsigned conversion.
    constexpr size_t lanes = N;

    // Kept non-const on purpose: reduction functors are allowed to declare
    // non-const call operators.
    Reduction op;

    for (int row = 0; row < size; ++row) {
      // Every row starts again at the front of the accumulator, while `x`
      // keeps advancing through the input.
      T* out = accum;
      size_t remaining = stride;

      // Bulk part: combine N elements per step through the vector overload.
      for (; remaining >= lanes; remaining -= lanes) {
        const VT* src = reinterpret_cast<const VT*>(x);
        VT* dst = reinterpret_cast<VT*>(out);
        *dst = op(*src, *dst);
        x += N;
        out += N;
      }

      // Tail part: fewer than N elements left, use the scalar overload.
      for (; remaining > 0; --remaining) {
        *out = op(*out, *x);
        ++out;
        ++x;
      }
    }
  }
};
} // namespace
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) { void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
@ -94,10 +81,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out,
axes_, axes_,
0, 0,
[](const auto* x, auto* accum, int size, size_t stride) { StridedReduce<
_vectorized_strided_sum<float, simd_float16, 16>( float,
(const float*)x, (float*)accum, size, stride); simd_float16,
}, 16,
SumReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) { [](const auto* x, auto* accum, int size) {
float acc; float acc;
vDSP_sve((const float*)x, 1, &acc, size); vDSP_sve((const float*)x, 1, &acc, size);
@ -111,10 +99,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out,
axes_, axes_,
-std::numeric_limits<float>::infinity(), -std::numeric_limits<float>::infinity(),
[](const auto* x, auto* accum, int size, size_t stride) { StridedReduce<
_vectorized_strided_max<float, simd_float16, 16>( float,
(const float*)x, (float*)accum, size, stride); simd_float16,
}, 16,
MaxReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) { [](const auto* x, auto* accum, int size) {
float max; float max;
vDSP_maxv((const float*)x, 1, &max, size); vDSP_maxv((const float*)x, 1, &max, size);
@ -128,10 +117,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out,
axes_, axes_,
std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity(),
[](const auto* x, auto* accum, int size, size_t stride) { StridedReduce<
_vectorized_strided_min<float, simd_float16, 16>( float,
(const float*)x, (float*)accum, size, stride); simd_float16,
}, 16,
MinReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) { [](const auto* x, auto* accum, int size) {
float min; float min;
vDSP_minv((const float*)x, 1, &min, size); vDSP_minv((const float*)x, 1, &min, size);