From 0ae22b915b07b8b9357dad44892b972d8f87daa1 Mon Sep 17 00:00:00 2001 From: nicolov Date: Mon, 11 Mar 2024 18:57:07 +0100 Subject: [PATCH] Remove code duplication in reduce ops (#793) * Remove code duplication in reduce ops * Remove the unnecessary lambda --------- Co-authored-by: Angelos Katharopoulos --- mlx/backend/accelerate/reduce.cpp | 144 ++++++++++++++---------------- 1 file changed, 67 insertions(+), 77 deletions(-) diff --git a/mlx/backend/accelerate/reduce.cpp b/mlx/backend/accelerate/reduce.cpp index db2b8eba2..15a5d83b9 100644 --- a/mlx/backend/accelerate/reduce.cpp +++ b/mlx/backend/accelerate/reduce.cpp @@ -10,78 +10,65 @@ namespace mlx::core { -template -void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) { - for (int i = 0; i < size; i++) { - size_t s = stride; - T* a = accum; - while (s >= N) { - VT val = (*(VT*)x); - *(VT*)a += val; - x += N; - a += N; - s -= N; - } - while (s-- > 0) { - *a++ += *x++; - } - } -} +namespace { -// TODO: Add proper templates for the strided reduce algorithm so we don't have -// to write max/min/sum etc. -template -void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) { - for (int i = 0; i < size; i++) { - size_t s = stride; - T* a = accum; - while (s >= N) { - *(VT*)a = simd_max((*(VT*)x), (*(VT*)a)); - x += N; - a += N; - s -= N; - } - while (s-- > 0) { - *a = std::max(*a, *x); - a++; - x++; - } +template +struct MinReduction { + T operator()(const T& a, const T& b) { + return std::min(a, b); } -} -template -void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) { - for (int i = 0; i < size; i++) { - size_t s = stride; - T* a = accum; - while (s >= N) { - *(VT*)a = simd_min((*(VT*)x), (*(VT*)a)); - x += N; - a += N; - s -= N; - } - while (s-- > 0) { - *a = std::min(*a, *x); - a++; - x++; - } + VT operator()(VT a, VT b) { + return simd_min(a, b); } -} +}; -template -void _vectorized_sum(const T* x, T* accum, int size) { - VT _sum = {0}; - while (size >= N) { - _sum += (*(VT*)x); - x += N; - size -= N; +template +struct MaxReduction { + T operator()(const T& a, const T& b) { + return std::max(a, b); } - T sum = _sum[0]; - for (int i = 1; i < N; i++) { - sum += _sum[i]; + + VT operator()(VT a, VT b) { + return simd_max(a, b); } - *accum += sum; -} +}; + +template +struct SumReduction { + T operator()(const T& a, const T& b) { + return a + b; + } + + VT operator()(VT a, VT b) { + return a + b; + } +}; + +template +struct StridedReduce { + void operator()(const T* x, T* accum, int size, size_t stride) { + Reduction op; + + for (int i = 0; i < size; i++) { + size_t s = stride; + T* a = accum; + while (s >= N) { + *(VT*)a = op((*(VT*)x), (*(VT*)a)); + x += N; + a += N; + s -= N; + } + while (s-- > 0) { + *a = op(*a, *x); + a++; + x++; + } + } + } +}; + +} // namespace void Reduce::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); @@ -94,10 +81,11 @@ void Reduce::eval_cpu(const std::vector& inputs, array& out) { out, axes_, 0, - [](const auto* x, auto* accum, int size, size_t stride) { - _vectorized_strided_sum( - (const float*)x, (float*)accum, size, stride); - }, + StridedReduce< + float, + simd_float16, + 16, + SumReduction>(), [](const auto* x, auto* accum, int size) { float acc; vDSP_sve((const float*)x, 1, &acc, size); @@ -111,10 +99,11 @@ void Reduce::eval_cpu(const std::vector& inputs, array& out) { out, axes_, -std::numeric_limits::infinity(), - [](const auto* x, auto* accum, int size, size_t stride) { - _vectorized_strided_max( - (const float*)x, (float*)accum, size, stride); - }, + StridedReduce< + float, + simd_float16, + 16, + MaxReduction>(), [](const auto* x, auto* accum, int size) { float max; vDSP_maxv((const float*)x, 1, &max, size); @@ -128,10 +117,11 @@ void Reduce::eval_cpu(const std::vector& inputs, array& out) { out, axes_, std::numeric_limits::infinity(), - [](const auto* x, auto* accum, int size, size_t stride) { - _vectorized_strided_min( - (const float*)x, (float*)accum, size, stride); - }, + StridedReduce< + float, + simd_float16, + 16, + MinReduction>(), [](const auto* x, auto* accum, int size) { float min; vDSP_minv((const float*)x, 1, &min, size);