Remove code duplication in reduce ops (#793)

* Remove code duplication in reduce ops

* Remove the unnecessary lambda

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
nicolov 2024-03-11 18:57:07 +01:00 committed by GitHub
parent 7c441600fe
commit 0ae22b915b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -10,78 +10,65 @@
namespace mlx::core { namespace mlx::core {
template <typename T, typename VT, int N> namespace {
void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) { template <typename T, typename VT>
size_t s = stride; struct MinReduction {
T* a = accum; T operator()(const T& a, const T& b) {
while (s >= N) { return std::min(a, b);
VT val = (*(VT*)x);
*(VT*)a += val;
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a++ += *x++;
}
}
} }
// TODO: Add proper templates for the strided reduce algorithm so we don't have VT operator()(VT a, VT b) {
// to write max/min/sum etc. return simd_min(a, b);
template <typename T, typename VT, int N> }
void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) { };
// Binary "maximum" functor with two call forms: one for single values of
// type T and one for packed values of type VT. Both forms are stateless.
template <typename T, typename VT>
struct MaxReduction {
  // Scalar form: returns the larger of the two operands via std::max.
  T operator()(const T& lhs, const T& rhs) {
    return std::max(lhs, rhs);
  }

  // Packed form: forwards both operands to simd_max.
  // NOTE(review): simd_max is declared outside this view; presumably it is
  // the lane-wise maximum from the platform simd header -- confirm.
  VT operator()(VT lhs, VT rhs) {
    return simd_max(lhs, rhs);
  }
};
// Binary "addition" functor with two call forms: one for single values of
// type T and one for packed values of type VT. Both forms simply apply
// operator+ to their operands.
template <typename T, typename VT>
struct SumReduction {
  // Scalar form: lhs + rhs on values of type T.
  T operator()(const T& lhs, const T& rhs) {
    return lhs + rhs;
  }

  // Packed form: lhs + rhs on values of type VT.
  VT operator()(VT lhs, VT rhs) {
    return lhs + rhs;
  }
};
template <typename T, typename VT, int N, typename Reduction>
struct StridedReduce {
void operator()(const T* x, T* accum, int size, size_t stride) {
Reduction op;
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
size_t s = stride; size_t s = stride;
T* a = accum; T* a = accum;
while (s >= N) { while (s >= N) {
*(VT*)a = simd_max((*(VT*)x), (*(VT*)a)); *(VT*)a = op((*(VT*)x), (*(VT*)a));
x += N; x += N;
a += N; a += N;
s -= N; s -= N;
} }
while (s-- > 0) { while (s-- > 0) {
*a = std::max(*a, *x); *a = op(*a, *x);
a++; a++;
x++; x++;
} }
} }
} }
};
template <typename T, typename VT, int N> } // namespace
// Folds `size` consecutive runs of `stride` elements read from `x` into the
// same `stride`-long buffer at `accum`, keeping the smaller value at each
// position. T, VT and N are template parameters declared on the preceding
// line.
// NOTE(review): the VT casts below assume VT packs N values of T and that
// both pointers are acceptable for VT-sized access -- confirm with callers.
void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
// Each run restarts at the front of accum; x is never rewound, so it keeps
// advancing through the input across runs.
size_t s = stride;
T* a = accum;
// Packed path: handle N elements per step through simd_min.
while (s >= N) {
*(VT*)a = simd_min((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
// Scalar path: the fewer-than-N elements left in this run.
while (s-- > 0) {
*a = std::min(*a, *x);
a++;
x++;
}
}
}
// Adds the first `size` elements of `x` to `*accum`.
//
// Whole groups of N elements are accumulated through the packed type VT;
// the lanes are then collapsed to one scalar. Any trailing `size % N`
// elements are added one by one (previously they were silently ignored, so
// the result was wrong whenever `size` was not a multiple of N).
//
// x:     input values; advanced locally, the caller's pointer is untouched.
// accum: receives `*accum + sum(x[0..size))`.
// size:  number of elements to add; values <= 0 leave `*accum` unchanged
//        apart from adding zero.
//
// NOTE(review): assumes VT holds exactly N values of T and that `x` may be
// read through a VT pointer (alignment) -- confirm against callers.
template <typename T, typename VT, int N>
void _vectorized_sum(const T* x, T* accum, int size) {
  // Per-lane partial sums, all lanes starting at zero.
  VT lanes = {0};
  while (size >= N) {
    lanes += *reinterpret_cast<const VT*>(x);
    x += N;
    size -= N;
  }

  // Collapse the lanes into a single scalar.
  T sum = lanes[0];
  for (int i = 1; i < N; i++) {
    sum += lanes[i];
  }

  // Scalar tail: the fewer-than-N elements the packed loop could not take.
  while (size-- > 0) {
    sum += *x++;
  }

  *accum += sum;
}
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) { void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
@ -94,10 +81,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out,
axes_, axes_,
0, 0,
[](const auto* x, auto* accum, int size, size_t stride) { StridedReduce<
_vectorized_strided_sum<float, simd_float16, 16>( float,
(const float*)x, (float*)accum, size, stride); simd_float16,
}, 16,
SumReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) { [](const auto* x, auto* accum, int size) {
float acc; float acc;
vDSP_sve((const float*)x, 1, &acc, size); vDSP_sve((const float*)x, 1, &acc, size);
@ -111,10 +99,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out,
axes_, axes_,
-std::numeric_limits<float>::infinity(), -std::numeric_limits<float>::infinity(),
[](const auto* x, auto* accum, int size, size_t stride) { StridedReduce<
_vectorized_strided_max<float, simd_float16, 16>( float,
(const float*)x, (float*)accum, size, stride); simd_float16,
}, 16,
MaxReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) { [](const auto* x, auto* accum, int size) {
float max; float max;
vDSP_maxv((const float*)x, 1, &max, size); vDSP_maxv((const float*)x, 1, &max, size);
@ -128,10 +117,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out,
axes_, axes_,
std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity(),
[](const auto* x, auto* accum, int size, size_t stride) { StridedReduce<
_vectorized_strided_min<float, simd_float16, 16>( float,
(const float*)x, (float*)accum, size, stride); simd_float16,
}, 16,
MinReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) { [](const auto* x, auto* accum, int size) {
float min; float min;
vDSP_minv((const float*)x, 1, &min, size); vDSP_minv((const float*)x, 1, &min, size);