Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
Refactor common into cpu specific and truly common (#1817)
* refactor
* fix extension example
* fix no-cpu
@@ -1,377 +1,147 @@
-// Copyright © 2023 Apple Inc.
-
-#include <cassert>
-#include <functional>
-#include <limits>
+// Copyright © 2024 Apple Inc.
 
 #include "mlx/backend/common/reduce.h"
-#include "mlx/backend/common/simd/simd.h"
-#include "mlx/primitives.h"
 
 namespace mlx::core {
 
-namespace {
-
-template <typename U>
-struct Limits {
-  static const U max;
-  static const U min;
-};
-
-#define instantiate_default_limit(type)                              \
-  template <>                                                        \
-  struct Limits<type> {                                              \
-    static constexpr type max = std::numeric_limits<type>::max();    \
-    static constexpr type min = std::numeric_limits<type>::min();    \
-  };
-
-instantiate_default_limit(uint8_t);
-instantiate_default_limit(uint16_t);
-instantiate_default_limit(uint32_t);
-instantiate_default_limit(uint64_t);
-instantiate_default_limit(int8_t);
-instantiate_default_limit(int16_t);
-instantiate_default_limit(int32_t);
-instantiate_default_limit(int64_t);
-
-#define instantiate_float_limit(type) \
-  template <>                         \
-  struct Limits<type> {               \
-    static const type max;            \
-    static const type min;            \
-  };
-
-instantiate_float_limit(float16_t);
-instantiate_float_limit(bfloat16_t);
-instantiate_float_limit(float);
-instantiate_float_limit(complex64_t);
-
-template <>
-struct Limits<bool> {
-  static constexpr bool max = true;
-  static constexpr bool min = false;
-};
-
-const float Limits<float>::max = std::numeric_limits<float>::infinity();
-const float Limits<float>::min = -std::numeric_limits<float>::infinity();
-const bfloat16_t Limits<bfloat16_t>::max =
-    std::numeric_limits<float>::infinity();
-const bfloat16_t Limits<bfloat16_t>::min =
-    -std::numeric_limits<float>::infinity();
-const float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();
-const float16_t Limits<float16_t>::min =
-    -std::numeric_limits<float>::infinity();
-const complex64_t Limits<complex64_t>::max =
-    std::numeric_limits<float>::infinity();
-const complex64_t Limits<complex64_t>::min =
-    -std::numeric_limits<float>::infinity();
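// The floating-point specializations above use +/-infinity rather than
// numeric_limits<T>::max()/lowest(), so Limits<T>::max and Limits<T>::min can
// serve directly as identity elements for the min/max reductions further down.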
-struct AndReduce {
-  template <typename T>
-  bool operator()(bool x, T y) {
-    return x & (y != 0);
-  }
-
-  bool operator()(bool x, bool y) {
-    return x & y;
-  }
-
-  template <int N, typename T>
-  simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
-    return x & (y != 0);
-  };
-
-  template <int N>
-  simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
-    return x & y;
-  };
-
-  template <int N, typename T>
-  bool operator()(simd::Simd<T, N> x) {
-    return simd::all(x);
-  };
-};
-
-struct OrReduce {
-  template <typename T>
-  bool operator()(bool x, T y) {
-    return x | (y != 0);
-  }
-
-  bool operator()(bool x, bool y) {
-    return x | y;
-  }
-
-  template <int N, typename T>
-  simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
-    return x | (y != 0);
-  };
-
-  template <int N>
-  simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
-    return x | y;
-  };
-
-  template <int N, typename T>
-  bool operator()(simd::Simd<T, N> x) {
-    return simd::any(x);
-  };
-};
-
-struct MaxReduce {
-  template <typename T>
-  T operator()(T y, T x) {
-    return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
-  };
-
-  template <int N, typename T>
-  simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
-    return simd::maximum(x, y);
-  };
-
-  template <int N, typename T>
-  T operator()(simd::Simd<T, N> x) {
-    return simd::max(x);
-  };
-};
-
-struct MinReduce {
-  template <typename T>
-  T operator()(T y, T x) {
-    return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
-  };
-
-  template <int N, typename T>
-  simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
-    return simd::minimum(x, y);
-  };
-
-  template <int N, typename T>
-  T operator()(simd::Simd<T, N> x) {
-    return simd::min(x);
-  };
-};
-
-struct SumReduce {
-  template <typename T, typename U>
-  U operator()(U y, T x) {
-    return x + y;
-  };
-
-  template <int N, typename T, typename U>
-  simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
-    return y + x;
-  };
-
-  template <int N, typename T>
-  T operator()(simd::Simd<T, N> x) {
-    return simd::sum(x);
-  };
-};
-
-struct ProdReduce {
-  template <typename T, typename U>
-  U operator()(U y, T x) {
-    return x * y;
-  };
-
-  template <int N, typename T, typename U>
-  simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
-    return x * y;
-  };
-
-  template <int N, typename T>
-  T operator()(simd::Simd<T, N> x) {
-    return simd::prod(x);
-  };
-};
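// A minimal sketch of how these reducer functors are driven on the scalar
// path. The real traversal lives in reduction_op (declared in the reduce
// header), so this helper exists only to illustrate the accumulate step.
inline float example_sum(const float* data, int n) {
  SumReduce op;
  float acc = 0.0f; // identity element, passed as `init` by the dispatch code
  for (int i = 0; i < n; i++) {
    acc = op(acc, data[i]); // scalar overload: U operator()(U y, T x)
  }
  return acc; // {1, 2, 3, 4} -> 10
}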
-template <typename InT>
-void reduce_dispatch_and_or(
-    const array& in,
-    array& out,
-    Reduce::ReduceType rtype,
+std::pair<Shape, Strides> shapes_without_reduction_axes(
+    const array& x,
     const std::vector<int>& axes) {
-  if (rtype == Reduce::And) {
-    reduction_op<InT, bool>(in, out, axes, true, AndReduce());
-  } else {
-    reduction_op<InT, bool>(in, out, axes, false, OrReduce());
+  auto shape = x.shape();
+  auto strides = x.strides();
+
+  for (int i = axes.size() - 1; i >= 0; i--) {
+    int a = axes[i];
+    shape.erase(shape.begin() + a);
+    strides.erase(strides.begin() + a);
   }
+
+  return std::make_pair(shape, strides);
 }
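// Usage sketch for shapes_without_reduction_axes, assuming a row-major array
// x with shape {2, 3, 4} and strides {12, 4, 1}: reducing over axes {1} gives
// back the pair {{2, 4}, {12, 1}}, i.e. the reduced dimension is simply erased
// from both the shape and the strides.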
-template <typename InT>
-void reduce_dispatch_sum_prod(
-    const array& in,
-    array& out,
-    Reduce::ReduceType rtype,
-    const std::vector<int>& axes) {
-  if (rtype == Reduce::Sum) {
-    if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
-      reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
-    } else {
-      reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
+ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
+  // The data is all there and we are reducing over everything
+  if (x.size() == x.data_size() && axes.size() == x.ndim() &&
+      x.flags().contiguous) {
+    return ContiguousAllReduce;
+  }
+
+  // Row contiguous input so the output is row contiguous
+  if (x.flags().row_contiguous) {
+    // Merge consecutive axes
+    Shape shape = {x.shape(axes[0])};
+    Strides strides = {x.strides()[axes[0]]};
+    for (int i = 1; i < axes.size(); i++) {
+      if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
+        shape.back() *= x.shape(axes[i]);
+        strides.back() = x.strides()[axes[i]];
+      } else {
+        shape.push_back(x.shape(axes[i]));
+        strides.push_back(x.strides()[axes[i]]);
+      }
+    }
-  } else {
-    if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
-      reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
-    } else {
-      reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
+
+    // Remove singleton axes from the plan
+    for (int i = shape.size() - 1; i >= 0; i--) {
+      if (shape[i] == 1) {
+        shape.erase(shape.begin() + i);
+        strides.erase(strides.begin() + i);
+      }
     }
 
+    if (strides.back() == 1) {
+      return ReductionPlan(ContiguousReduce, shape, strides);
+    } else if (strides.back() > 1) {
+      return ReductionPlan(ContiguousStridedReduce, shape, strides);
+    }
   }
-}
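// In the sum/prod dispatch above, integer inputs of at most 32 bits accumulate
// into an int32_t output, while 64-bit integers and floating-point types
// accumulate in their own element type.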
-template <typename InT>
-void reduce_dispatch_min_max(
-    const array& in,
-    array& out,
-    Reduce::ReduceType rtype,
-    const std::vector<int>& axes) {
-  if (rtype == Reduce::Max) {
-    auto init = Limits<InT>::min;
-    reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
-  } else {
-    auto init = Limits<InT>::max;
-    reduction_op<InT, InT>(in, out, axes, init, MinReduce());
-  }
-}
+  // Let's check if we can optimize our access patterns
+  //
+  // 1. We have a reduction axis with stride 1. Simply call
+  //    GeneralContiguousReduce and be done with it.
+  // 2. We have transpositions and we are not reducing over the axis with
+  //    stride 1. However, we are reducing over an axis where everything is
+  //    contiguous in memory to the right of that axis. We can call strided
+  //    reduce and be done with it.
+  // 3. We have weird transpositions and expands. Copy the strides to the
+  //    output, then call strided reduce.
 
-} // namespace
-
-void nd_loop(
-    std::function<void(int)> callback,
-    const Shape& shape,
-    const Strides& strides) {
-  std::function<void(int, int)> loop_inner;
-  loop_inner = [&](int dim, int offset) {
-    if (dim < shape.size() - 1) {
-      auto size = shape[dim];
-      auto stride = strides[dim];
-      for (int i = 0; i < size; i++) {
-        loop_inner(dim + 1, offset + i * stride);
-      }
-    } else {
-      auto size = shape[dim];
-      auto stride = strides[dim];
-      for (int i = 0; i < size; i++) {
-        callback(offset + i * stride);
-      }
-    }
-  };
-  loop_inner(0, 0);
-}
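// Usage sketch for nd_loop, assuming shape {2, 3}: with row-major strides
// {3, 1} the callback receives the flat offsets 0, 1, 2, 3, 4, 5 in order,
//   nd_loop([](int offset) { /* visit element at offset */ }, {2, 3}, {3, 1});
// while transposed strides {1, 2} make the same call visit 0, 2, 4, 1, 3, 5.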
-void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  auto& in = inputs[0];
-  switch (reduce_type_) {
-    case Reduce::And:
-    case Reduce::Or: {
-      switch (in.dtype()) {
-        case bool_:
-        case uint8:
-        case int8:
-          reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
-          break;
-        case int16:
-        case uint16:
-        case float16:
-        case bfloat16:
-          reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
-          break;
-        case uint32:
-        case int32:
-        case float32:
-          reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
-          break;
-        case uint64:
-        case int64:
-        case complex64:
-          reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
-          break;
-      }
-      break;
-    }
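    // The And/Or dispatch keys only on element width: bool, uint8 and int8
    // share the int8_t instantiation (and similarly for the 16-, 32- and
    // 64-bit groups), since a logical reduction only needs to test `x != 0`,
    // not the sign or the exact element type.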
-    case Reduce::Sum:
-    case Reduce::Prod: {
-      switch (in.dtype()) {
-        case bool_:
-        case uint8:
-        case int8:
-          reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
-          break;
-        case int16:
-        case uint16:
-          reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
-          break;
-        case int32:
-        case uint32:
-          reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
-          break;
-        case int64:
-        case uint64:
-          reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
-          break;
-        case float16:
-          reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
-          break;
-        case bfloat16:
-          reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
-          break;
-        case float32:
-          reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
-          break;
-        case complex64:
-          reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
-          break;
-      }
-      break;
-    }
-    case Reduce::Max:
-    case Reduce::Min: {
-      switch (in.dtype()) {
-        case bool_:
-          reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
-          break;
-        case uint8:
-          reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
-          break;
-        case uint16:
-          reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
-          break;
-        case uint32:
-          reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
-          break;
-        case uint64:
-          reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
-          break;
-        case int8:
-          reduce_dispatch_min_max<int8_t>(in, out, reduce_type_, axes_);
-          break;
-        case int16:
-          reduce_dispatch_min_max<int16_t>(in, out, reduce_type_, axes_);
-          break;
-        case int32:
-          reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
-          break;
-        case int64:
-          reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
-          break;
-        case float16:
-          reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
-          break;
-        case float32:
-          reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
-          break;
-        case bfloat16:
-          reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
-          break;
-        case complex64:
-          reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
-          break;
-      }
-      break;
+  // Sort reduction axes by stride in order to merge them and figure out if we
+  // have a contiguous reduction.
+  std::vector<std::pair<int, int64_t>> reductions;
+  for (auto a : axes) {
+    if (x.shape(a) > 1) {
+      reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
+    }
+  }
+  std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
+    bool a_is_zero = a.second == 0;
+    bool b_is_zero = b.second == 0;
+    return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
+  });
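  // The comparator above puts zero-stride (broadcast) axes first and sorts the
  // rest by decreasing stride, so reductions.back() ends up holding the
  // smallest non-broadcast stride.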
+  // Extract the two smallest and try to merge them in case the contiguous
+  // reduction can be bigger than just the last axis.
+  for (int i = reductions.size() - 1; i >= 1; i--) {
+    auto a = reductions[i];
+    auto b = reductions[i - 1];
+
+    // If b.stride == a.shape * a.stride then a and b are contiguous
+    if (b.second == a.first * a.second) {
+      reductions.erase(reductions.begin() + i);
+      reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
+    }
+  }
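  // Example of the merge: entries (shape 3, stride 32) followed by
  // (shape 4, stride 8) satisfy 32 == 4 * 8, so they collapse into a single
  // contiguous reduction entry of (shape 12, stride 8).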
+
+  Shape shape;
+  Strides strides;
+  for (auto r : reductions) {
+    shape.push_back(r.first);
+    strides.push_back(r.second);
+  }
+
+  // We can call the contiguous reduction op for every weird way the input is
+  // structured in the rest of the axes.
+  if (strides.back() == 1) {
+    return ReductionPlan(GeneralContiguousReduce, shape, strides);
+  }
+
+  // Delegate to the general strided reduction op if the axes after
+  // strides.back() are contiguous.
+  if (strides.back() > 1) {
+    int64_t size = 1;
+    bool have_expand = false;
+    for (int i = x.ndim() - 1; i >= 0; i--) {
+      if (axes.back() == i) {
+        continue;
+      }
+
+      auto stride_i = x.strides()[i];
+      auto shape_i = x.shape(i);
+      if (stride_i == 0) {
+        if (shape_i == 1) {
+          continue;
+        }
+
+        have_expand = true;
+        break;
+      }
+
+      if (stride_i != size && shape_i != 1) {
+        break;
+      }
+      size *= shape_i;
+    }
+    // In the case of an expanded dimension we are being conservative and
+    // require the smallest reduction stride to be smaller than the maximum row
+    // contiguous size. The reason is that we can't easily know if the reduced
+    // axis is before or after an expanded dimension.
+    if (size > strides.back() || (size == strides.back() && !have_expand)) {
+      return ReductionPlan(GeneralStridedReduce, shape, strides);
+    }
+  }
+
+  return ReductionPlan(GeneralReduce, shape, strides);
+}
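// A few illustrative outcomes of get_reduction_plan, assuming a row-major,
// fully contiguous input:
//   shape {16, 32}, axes {0, 1} -> ContiguousAllReduce
//   shape {16, 32}, axes {1}    -> ContiguousReduce with plan ({32}, {1})
//   shape {16, 32}, axes {0}    -> ContiguousStridedReduce with plan ({16}, {32})
// Transposed or broadcast inputs fall through to the General* plans above.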
} // namespace mlx::core