mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
Remove code duplication in reduce ops (#793)
* Remove code duplication in reduce ops
* Remove the unnecessary lambda

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent
7c441600fe
commit
0ae22b915b
@ -10,78 +10,65 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
template <typename T, typename VT, int N>
|
namespace {
|
||||||
void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) {
|
|
||||||
for (int i = 0; i < size; i++) {
|
|
||||||
size_t s = stride;
|
|
||||||
T* a = accum;
|
|
||||||
while (s >= N) {
|
|
||||||
VT val = (*(VT*)x);
|
|
||||||
*(VT*)a += val;
|
|
||||||
x += N;
|
|
||||||
a += N;
|
|
||||||
s -= N;
|
|
||||||
}
|
|
||||||
while (s-- > 0) {
|
|
||||||
*a++ += *x++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Add proper templates for the strided reduce algorithm so we don't have
|
template <typename T, typename VT>
|
||||||
// to write max/min/sum etc.
|
struct MinReduction {
|
||||||
template <typename T, typename VT, int N>
|
T operator()(const T& a, const T& b) {
|
||||||
void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) {
|
return std::min(a, b);
|
||||||
for (int i = 0; i < size; i++) {
|
|
||||||
size_t s = stride;
|
|
||||||
T* a = accum;
|
|
||||||
while (s >= N) {
|
|
||||||
*(VT*)a = simd_max((*(VT*)x), (*(VT*)a));
|
|
||||||
x += N;
|
|
||||||
a += N;
|
|
||||||
s -= N;
|
|
||||||
}
|
|
||||||
while (s-- > 0) {
|
|
||||||
*a = std::max(*a, *x);
|
|
||||||
a++;
|
|
||||||
x++;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename VT, int N>
|
VT operator()(VT a, VT b) {
|
||||||
void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) {
|
return simd_min(a, b);
|
||||||
for (int i = 0; i < size; i++) {
|
|
||||||
size_t s = stride;
|
|
||||||
T* a = accum;
|
|
||||||
while (s >= N) {
|
|
||||||
*(VT*)a = simd_min((*(VT*)x), (*(VT*)a));
|
|
||||||
x += N;
|
|
||||||
a += N;
|
|
||||||
s -= N;
|
|
||||||
}
|
|
||||||
while (s-- > 0) {
|
|
||||||
*a = std::min(*a, *x);
|
|
||||||
a++;
|
|
||||||
x++;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
template <typename T, typename VT, int N>
|
template <typename T, typename VT>
|
||||||
void _vectorized_sum(const T* x, T* accum, int size) {
|
struct MaxReduction {
|
||||||
VT _sum = {0};
|
T operator()(const T& a, const T& b) {
|
||||||
while (size >= N) {
|
return std::max(a, b);
|
||||||
_sum += (*(VT*)x);
|
|
||||||
x += N;
|
|
||||||
size -= N;
|
|
||||||
}
|
}
|
||||||
T sum = _sum[0];
|
|
||||||
for (int i = 1; i < N; i++) {
|
VT operator()(VT a, VT b) {
|
||||||
sum += _sum[i];
|
return simd_max(a, b);
|
||||||
}
|
}
|
||||||
*accum += sum;
|
};
|
||||||
}
|
|
||||||
|
template <typename T, typename VT>
|
||||||
|
struct SumReduction {
|
||||||
|
T operator()(const T& a, const T& b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
|
||||||
|
VT operator()(VT a, VT b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename VT, int N, typename Reduction>
|
||||||
|
struct StridedReduce {
|
||||||
|
void operator()(const T* x, T* accum, int size, size_t stride) {
|
||||||
|
Reduction op;
|
||||||
|
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
size_t s = stride;
|
||||||
|
T* a = accum;
|
||||||
|
while (s >= N) {
|
||||||
|
*(VT*)a = op((*(VT*)x), (*(VT*)a));
|
||||||
|
x += N;
|
||||||
|
a += N;
|
||||||
|
s -= N;
|
||||||
|
}
|
||||||
|
while (s-- > 0) {
|
||||||
|
*a = op(*a, *x);
|
||||||
|
a++;
|
||||||
|
x++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
@ -94,10 +81,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
out,
|
out,
|
||||||
axes_,
|
axes_,
|
||||||
0,
|
0,
|
||||||
[](const auto* x, auto* accum, int size, size_t stride) {
|
StridedReduce<
|
||||||
_vectorized_strided_sum<float, simd_float16, 16>(
|
float,
|
||||||
(const float*)x, (float*)accum, size, stride);
|
simd_float16,
|
||||||
},
|
16,
|
||||||
|
SumReduction<float, simd_float16>>(),
|
||||||
[](const auto* x, auto* accum, int size) {
|
[](const auto* x, auto* accum, int size) {
|
||||||
float acc;
|
float acc;
|
||||||
vDSP_sve((const float*)x, 1, &acc, size);
|
vDSP_sve((const float*)x, 1, &acc, size);
|
||||||
@ -111,10 +99,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
out,
|
out,
|
||||||
axes_,
|
axes_,
|
||||||
-std::numeric_limits<float>::infinity(),
|
-std::numeric_limits<float>::infinity(),
|
||||||
[](const auto* x, auto* accum, int size, size_t stride) {
|
StridedReduce<
|
||||||
_vectorized_strided_max<float, simd_float16, 16>(
|
float,
|
||||||
(const float*)x, (float*)accum, size, stride);
|
simd_float16,
|
||||||
},
|
16,
|
||||||
|
MaxReduction<float, simd_float16>>(),
|
||||||
[](const auto* x, auto* accum, int size) {
|
[](const auto* x, auto* accum, int size) {
|
||||||
float max;
|
float max;
|
||||||
vDSP_maxv((const float*)x, 1, &max, size);
|
vDSP_maxv((const float*)x, 1, &max, size);
|
||||||
@ -128,10 +117,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
out,
|
out,
|
||||||
axes_,
|
axes_,
|
||||||
std::numeric_limits<float>::infinity(),
|
std::numeric_limits<float>::infinity(),
|
||||||
[](const auto* x, auto* accum, int size, size_t stride) {
|
StridedReduce<
|
||||||
_vectorized_strided_min<float, simd_float16, 16>(
|
float,
|
||||||
(const float*)x, (float*)accum, size, stride);
|
simd_float16,
|
||||||
},
|
16,
|
||||||
|
MinReduction<float, simd_float16>>(),
|
||||||
[](const auto* x, auto* accum, int size) {
|
[](const auto* x, auto* accum, int size) {
|
||||||
float min;
|
float min;
|
||||||
vDSP_minv((const float*)x, 1, &min, size);
|
vDSP_minv((const float*)x, 1, &min, size);
|
||||||
|
Loading…
Reference in New Issue
Block a user