// Copyright © 2023 Apple Inc. #include #include #include #include "mlx/backend/common/reduce.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/simd/simd.h" #include "mlx/primitives.h" namespace mlx::core { template struct Limits { static const U max; static const U min; }; #define instantiate_default_limit(type) \ template <> \ struct Limits { \ static constexpr type max = std::numeric_limits::max(); \ static constexpr type min = std::numeric_limits::min(); \ }; instantiate_default_limit(uint8_t); instantiate_default_limit(uint16_t); instantiate_default_limit(uint32_t); instantiate_default_limit(uint64_t); instantiate_default_limit(int8_t); instantiate_default_limit(int16_t); instantiate_default_limit(int32_t); instantiate_default_limit(int64_t); #define instantiate_float_limit(type) \ template <> \ struct Limits { \ static const type max; \ static const type min; \ }; instantiate_float_limit(float16_t); instantiate_float_limit(bfloat16_t); instantiate_float_limit(float); instantiate_float_limit(double); instantiate_float_limit(complex64_t); template <> struct Limits { static constexpr bool max = true; static constexpr bool min = false; }; const float Limits::max = std::numeric_limits::infinity(); const float Limits::min = -std::numeric_limits::infinity(); const bfloat16_t Limits::max = std::numeric_limits::infinity(); const bfloat16_t Limits::min = -std::numeric_limits::infinity(); const float16_t Limits::max = std::numeric_limits::infinity(); const float16_t Limits::min = -std::numeric_limits::infinity(); const double Limits::max = std::numeric_limits::infinity(); const double Limits::min = -std::numeric_limits::infinity(); const complex64_t Limits::max = std::numeric_limits::infinity(); const complex64_t Limits::min = -std::numeric_limits::infinity(); template void strided_reduce( const T* x, U* accumulator, int size, size_t stride, Op op) { constexpr int N = std::min(simd::max_size, simd::max_size); for (int i = 0; i < size; i++) { U* moving_accumulator = accumulator; auto s = stride; while (s >= N) { auto acc = simd::load(moving_accumulator); auto v = simd::Simd(simd::load(x)); simd::store(moving_accumulator, op(acc, v)); moving_accumulator += N; x += N; s -= N; } while (s-- > 0) { *moving_accumulator = op(*moving_accumulator, *x); moving_accumulator++; x++; } } }; template void contiguous_reduce(const T* x, U* accumulator, int size, Op op, U init) { constexpr int N = std::min(simd::max_size, simd::max_size); simd::Simd accumulator_v(init); while (size >= N) { accumulator_v = op(accumulator_v, simd::Simd(simd::load(x))); x += N; size -= N; } *accumulator = op(*accumulator, op(accumulator_v)); while (size-- > 0) { *accumulator = op(*accumulator, *x); x++; } } // Helper for the ndimensional strided loop void nd_loop( std::function callback, const Shape& shape, const Strides& strides) { std::function loop_inner; loop_inner = [&](int dim, int offset) { if (dim < shape.size() - 1) { auto size = shape[dim]; auto stride = strides[dim]; for (int i = 0; i < size; i++) { loop_inner(dim + 1, offset + i * stride); } } else { auto size = shape[dim]; auto stride = strides[dim]; for (int i = 0; i < size; i++) { callback(offset + i * stride); } } }; loop_inner(0, 0); } template void reduction_op( const array& x, array& out, const std::vector& axes, U init) { ReductionPlan plan = get_reduction_plan(x, axes); auto in_ptr = x.data(); auto out_ptr = out.data(); if (plan.type == ContiguousAllReduce) { *out_ptr = init; contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init); return; } if (plan.type == ContiguousReduce && plan.shape.size() == 1) { int reduction_size = plan.shape[0]; for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) { *out_ptr = init; contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init); } return; } if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) { int reduction_size = plan.shape.back(); plan.shape.pop_back(); plan.strides.pop_back(); // Unrolling the following loop (and implementing it in order for // ContiguousReduce) should hold extra performance boost. auto [shape, strides] = shapes_without_reduction_axes(x, axes); if (plan.shape.size() == 0) { for (int i = 0; i < out.size(); i++, out_ptr++) { int offset = elem_to_loc(i, shape, strides); *out_ptr = init; contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init); } } else { for (int i = 0; i < out.size(); i++, out_ptr++) { int offset = elem_to_loc(i, shape, strides); *out_ptr = init; nd_loop( [&](int extra_offset) { contiguous_reduce( in_ptr + offset + extra_offset, out_ptr, reduction_size, Op{}, init); }, plan.shape, plan.strides); } } return; } if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) { int reduction_size = plan.shape.back(); size_t reduction_stride = plan.strides.back(); plan.shape.pop_back(); plan.strides.pop_back(); for (int i = 0; i < out.size(); i += reduction_stride) { std::fill_n(out_ptr, reduction_stride, init); strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{}); in_ptr += reduction_stride * reduction_size; out_ptr += reduction_stride; } return; } if (plan.type == GeneralStridedReduce || plan.type == ContiguousStridedReduce) { int reduction_size = plan.shape.back(); size_t reduction_stride = plan.strides.back(); plan.shape.pop_back(); plan.strides.pop_back(); auto [shape, strides] = shapes_without_reduction_axes(x, axes); if (plan.shape.size() == 0) { for (int i = 0; i < out.size(); i += reduction_stride) { int offset = elem_to_loc(i, shape, strides); std::fill_n(out_ptr, reduction_stride, init); strided_reduce( in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{}); out_ptr += reduction_stride; } } else { for (int i = 0; i < out.size(); i += reduction_stride) { int offset = elem_to_loc(i, shape, strides); std::fill_n(out_ptr, reduction_stride, init); nd_loop( [&](int extra_offset) { strided_reduce( in_ptr + offset + extra_offset, out_ptr, reduction_size, reduction_stride, Op{}); }, plan.shape, plan.strides); out_ptr += reduction_stride; } } return; } if (plan.type == GeneralReduce) { auto [shape, strides] = shapes_without_reduction_axes(x, axes); for (int i = 0; i < out.size(); i++, out_ptr++) { int offset = elem_to_loc(i, shape, strides); U val = init; nd_loop( [&](int extra_offset) { val = Op{}(val, *(in_ptr + offset + extra_offset)); }, plan.shape, plan.strides); *out_ptr = val; } } } struct AndReduce { template bool operator()(bool x, T y) { return x & (y != 0); } bool operator()(bool x, bool y) { return x & y; } template simd::Simd operator()(simd::Simd y, simd::Simd x) { return x & (y != 0); }; template simd::Simd operator()(simd::Simd y, simd::Simd x) { return x & y; }; template bool operator()(simd::Simd x) { return simd::all(x); }; }; struct OrReduce { template bool operator()(bool x, T y) { return x | (y != 0); } bool operator()(bool x, bool y) { return x | y; } template simd::Simd operator()(simd::Simd y, simd::Simd x) { return x | (y != 0); }; template simd::Simd operator()(simd::Simd y, simd::Simd x) { return x | y; }; template bool operator()(simd::Simd x) { return simd::any(x); }; }; struct MaxReduce { template T operator()(T y, T x) { return (*this)(simd::Simd(x), simd::Simd(y)).value; }; template simd::Simd operator()(simd::Simd y, simd::Simd x) { return simd::maximum(x, y); }; template T operator()(simd::Simd x) { return simd::max(x); }; }; struct MinReduce { template T operator()(T y, T x) { return (*this)(simd::Simd(x), simd::Simd(y)).value; }; template simd::Simd operator()(simd::Simd y, simd::Simd x) { return simd::minimum(x, y); }; template T operator()(simd::Simd x) { return simd::min(x); }; }; struct SumReduce { template U operator()(U y, T x) { return x + y; }; template simd::Simd operator()(simd::Simd y, simd::Simd x) { return y + x; }; template T operator()(simd::Simd x) { return simd::sum(x); }; }; struct ProdReduce { template U operator()(U y, T x) { return x * y; }; template simd::Simd operator()(simd::Simd y, simd::Simd x) { return x * y; }; template T operator()(simd::Simd x) { return simd::prod(x); }; }; template void reduce_dispatch_and_or( const array& in, array& out, Reduce::ReduceType rtype, const std::vector& axes) { if (rtype == Reduce::And) { reduction_op(in, out, axes, true); } else { reduction_op(in, out, axes, false); } } template void reduce_dispatch_sum_prod( const array& in, array& out, Reduce::ReduceType rtype, const std::vector& axes) { if (rtype == Reduce::Sum) { if constexpr (std::is_integral_v && sizeof(InT) <= 4) { reduction_op(in, out, axes, 0); } else { reduction_op(in, out, axes, 0); } } else { if constexpr (std::is_integral_v && sizeof(InT) <= 4) { reduction_op(in, out, axes, 1); } else { reduction_op(in, out, axes, 1); } } } template void reduce_dispatch_min_max( const array& in, array& out, Reduce::ReduceType rtype, const std::vector& axes) { if (rtype == Reduce::Max) { auto init = Limits::min; reduction_op(in, out, axes, init); } else { auto init = Limits::max; reduction_op(in, out, axes, init); } } void Reduce::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; out.set_data(allocator::malloc(out.nbytes())); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_input_array(in); encoder.set_output_array(out); encoder.dispatch([in = array::unsafe_weak_copy(in), out = array::unsafe_weak_copy(out), reduce_type_ = reduce_type_, axes_ = axes_]() mutable { switch (reduce_type_) { case Reduce::And: case Reduce::Or: { switch (in.dtype()) { case bool_: case uint8: case int8: reduce_dispatch_and_or(in, out, reduce_type_, axes_); break; case int16: case uint16: case float16: case bfloat16: reduce_dispatch_and_or(in, out, reduce_type_, axes_); break; case uint32: case int32: case float32: reduce_dispatch_and_or(in, out, reduce_type_, axes_); break; case uint64: case int64: case float64: case complex64: reduce_dispatch_and_or(in, out, reduce_type_, axes_); break; } break; } case Reduce::Sum: case Reduce::Prod: { switch (in.dtype()) { case bool_: case uint8: case int8: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case int16: case uint16: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case int32: case uint32: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case int64: case uint64: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case float16: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case bfloat16: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case float32: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case float64: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case complex64: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; } break; } case Reduce::Max: case Reduce::Min: { switch (in.dtype()) { case bool_: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case uint8: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case uint16: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case uint32: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case uint64: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int8: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int16: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int32: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case int64: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case float16: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case float32: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case float64: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case bfloat16: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; case complex64: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; } break; } } }); } } // namespace mlx::core