mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Remove code duplication in reduce ops (#793)
* Remove code duplication in reduce ops
* Remove the unnecessary lambda

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
		| @@ -10,78 +10,65 @@ | ||||
|  | ||||
| namespace mlx::core { | ||||
|  | ||||
| template <typename T, typename VT, int N> /* Sums `size` back-to-back runs of `stride` elements from x into accum[0..stride). */ | ||||
| void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) { | ||||
|   for (int i = 0; i < size; i++) { /* one pass per run; x is never rewound */ | ||||
|     size_t s = stride; | ||||
|     T* a = accum; /* every run accumulates into the same accum[0..stride) */ | ||||
|     while (s >= N) { /* bulk path: N elements per step */ | ||||
|       VT val = (*(VT*)x); /* reads N elements as one VT; presumably VT is an N-lane vector of T -- confirm */ | ||||
|       *(VT*)a += val; | ||||
|       x += N; | ||||
|       a += N; | ||||
|       s -= N; | ||||
|     } | ||||
|     while (s-- > 0) { /* tail: fewer than N elements left in this run */ | ||||
|       *a++ += *x++; | ||||
|     } | ||||
|   } | ||||
| } | ||||
| namespace { | ||||
|  | ||||
| // TODO: Add proper templates for the strided reduce algorithm so we don't have | ||||
| // to write max/min/sum etc. | ||||
| template <typename T, typename VT, int N> | ||||
| void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) { | ||||
|   for (int i = 0; i < size; i++) { | ||||
|     size_t s = stride; | ||||
|     T* a = accum; | ||||
|     while (s >= N) { | ||||
|       *(VT*)a = simd_max((*(VT*)x), (*(VT*)a)); | ||||
|       x += N; | ||||
|       a += N; | ||||
|       s -= N; | ||||
|     } | ||||
|     while (s-- > 0) { | ||||
|       *a = std::max(*a, *x); | ||||
|       a++; | ||||
|       x++; | ||||
|     } | ||||
| template <typename T, typename VT> | ||||
| struct MinReduction { | ||||
|   T operator()(const T& a, const T& b) { | ||||
|     return std::min(a, b); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T, typename VT, int N> | ||||
| void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) { | ||||
|   for (int i = 0; i < size; i++) { | ||||
|     size_t s = stride; | ||||
|     T* a = accum; | ||||
|     while (s >= N) { | ||||
|       *(VT*)a = simd_min((*(VT*)x), (*(VT*)a)); | ||||
|       x += N; | ||||
|       a += N; | ||||
|       s -= N; | ||||
|     } | ||||
|     while (s-- > 0) { | ||||
|       *a = std::min(*a, *x); | ||||
|       a++; | ||||
|       x++; | ||||
|     } | ||||
|   VT operator()(VT a, VT b) { | ||||
|     return simd_min(a, b); | ||||
|   } | ||||
| } | ||||
| }; | ||||
|  | ||||
| template <typename T, typename VT, int N> | ||||
| void _vectorized_sum(const T* x, T* accum, int size) { | ||||
|   VT _sum = {0}; | ||||
|   while (size >= N) { | ||||
|     _sum += (*(VT*)x); | ||||
|     x += N; | ||||
|     size -= N; | ||||
| template <typename T, typename VT> | ||||
| struct MaxReduction { | ||||
|   T operator()(const T& a, const T& b) { | ||||
|     return std::max(a, b); | ||||
|   } | ||||
|   T sum = _sum[0]; | ||||
|   for (int i = 1; i < N; i++) { | ||||
|     sum += _sum[i]; | ||||
|  | ||||
|   VT operator()(VT a, VT b) { | ||||
|     return simd_max(a, b); | ||||
|   } | ||||
|   *accum += sum; | ||||
| } | ||||
| }; | ||||
|  | ||||
| template <typename T, typename VT> /* Addition functor with one overload for a single T and one for a VT value. */ | ||||
| struct SumReduction { | ||||
|   T operator()(const T& a, const T& b) { /* scalar overload */ | ||||
|     return a + b; | ||||
|   } | ||||
|   /* VT overload: relies on VT providing operator+; this file instantiates it with float / simd_float16. */ | ||||
|   VT operator()(VT a, VT b) { | ||||
|     return a + b; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <typename T, typename VT, int N, typename Reduction> /* Functor that folds `size` back-to-back runs of `stride` elements from x into accum[0..stride) using Reduction. */ | ||||
| struct StridedReduce { | ||||
|   void operator()(const T* x, T* accum, int size, size_t stride) { | ||||
|     Reduction op; /* default-constructed here, so Reduction must be default-constructible */ | ||||
|   /* x keeps advancing across runs, while `a` restarts at accum for each run. */ | ||||
|     for (int i = 0; i < size; i++) { | ||||
|       size_t s = stride; | ||||
|       T* a = accum; | ||||
|       while (s >= N) { /* bulk path: N elements per step through the VT overload of op */ | ||||
|         *(VT*)a = op((*(VT*)x), (*(VT*)a)); /* NOTE(review): the VT* casts presumably rely on x and accum meeting VT's alignment -- confirm */ | ||||
|         x += N; | ||||
|         a += N; | ||||
|         s -= N; | ||||
|       } | ||||
|       while (s-- > 0) { /* tail: fewer than N elements left in this run */ | ||||
|         *a = op(*a, *x); /* NOTE: operand order is (a, x) here but (x, a) in the bulk path; equivalent only for a commutative op */ | ||||
|         a++; | ||||
|         x++; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| }; | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) { | ||||
|   assert(inputs.size() == 1); | ||||
| @@ -94,10 +81,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) { | ||||
|           out, | ||||
|           axes_, | ||||
|           0, | ||||
|           [](const auto* x, auto* accum, int size, size_t stride) { | ||||
|             _vectorized_strided_sum<float, simd_float16, 16>( | ||||
|                 (const float*)x, (float*)accum, size, stride); | ||||
|           }, | ||||
|           StridedReduce< | ||||
|               float, | ||||
|               simd_float16, | ||||
|               16, | ||||
|               SumReduction<float, simd_float16>>(), | ||||
|           [](const auto* x, auto* accum, int size) { | ||||
|             float acc; | ||||
|             vDSP_sve((const float*)x, 1, &acc, size); | ||||
| @@ -111,10 +99,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) { | ||||
|           out, | ||||
|           axes_, | ||||
|           -std::numeric_limits<float>::infinity(), | ||||
|           [](const auto* x, auto* accum, int size, size_t stride) { | ||||
|             _vectorized_strided_max<float, simd_float16, 16>( | ||||
|                 (const float*)x, (float*)accum, size, stride); | ||||
|           }, | ||||
|           StridedReduce< | ||||
|               float, | ||||
|               simd_float16, | ||||
|               16, | ||||
|               MaxReduction<float, simd_float16>>(), | ||||
|           [](const auto* x, auto* accum, int size) { | ||||
|             float max; | ||||
|             vDSP_maxv((const float*)x, 1, &max, size); | ||||
| @@ -128,10 +117,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) { | ||||
|           out, | ||||
|           axes_, | ||||
|           std::numeric_limits<float>::infinity(), | ||||
|           [](const auto* x, auto* accum, int size, size_t stride) { | ||||
|             _vectorized_strided_min<float, simd_float16, 16>( | ||||
|                 (const float*)x, (float*)accum, size, stride); | ||||
|           }, | ||||
|           StridedReduce< | ||||
|               float, | ||||
|               simd_float16, | ||||
|               16, | ||||
|               MinReduction<float, simd_float16>>(), | ||||
|           [](const auto* x, auto* accum, int size) { | ||||
|             float min; | ||||
|             vDSP_minv((const float*)x, 1, &min, size); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 nicolov
					nicolov