diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index c7ef4670f..b0b7f0bbb 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -37,13 +37,6 @@ endif() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) -if(MLX_BUILD_ACCELERATE) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate) -elseif(MLX_BUILD_CPU) - target_sources( - mlx - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp) -endif() if(MLX_BUILD_METAL) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) diff --git a/mlx/backend/accelerate/CMakeLists.txt b/mlx/backend/accelerate/CMakeLists.txt deleted file mode 100644 index 96add2ae5..000000000 --- a/mlx/backend/accelerate/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp) diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp deleted file mode 100644 index bfee1050f..000000000 --- a/mlx/backend/accelerate/primitives.cpp +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#include -#include - -#include - -#include "mlx/allocator.h" -#include "mlx/backend/common/binary.h" -#include "mlx/backend/common/copy.h" -#include "mlx/backend/common/unary.h" -#include "mlx/primitives.h" - -namespace mlx::core { - -void Scan::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (reduce_type_ == Scan::Sum && out.dtype() == float32 && - in.flags().row_contiguous && in.strides()[axis_] == 1 && !inclusive_) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); - int stride = in.shape(axis_); - int count = in.size() / stride; - const float* input = in.data(); - float* output = out.data(); - float s = 1.0; - if (!reverse_) { - for (int i = 0; i < count; i++) { - vDSP_vrsum(input - 1, 1, &s, output, 1, stride); - input += stride; - output += stride; - } - } else { - for (int i = 0; i < count; i++) { - input += stride - 1; - output += stride - 1; - vDSP_vrsum(input + 1, -1, &s, output, -1, stride); - input++; - output++; - } - } - } else { - eval(inputs, out); - } -} - -} // namespace mlx::core diff --git a/mlx/backend/accelerate/reduce.cpp b/mlx/backend/accelerate/reduce.cpp deleted file mode 100644 index 287243943..000000000 --- a/mlx/backend/accelerate/reduce.cpp +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#include - -#include -#include - -#include "mlx/backend/common/reduce.h" -#include "mlx/primitives.h" - -namespace mlx::core { - -namespace { - -template -struct MinReduction { - T operator()(const T& a, const T& b) { - return std::min(a, b); - } - - VT operator()(VT a, VT b) { - return simd_min(a, b); - } -}; - -template -struct MaxReduction { - T operator()(const T& a, const T& b) { - return std::max(a, b); - } - - VT operator()(VT a, VT b) { - return simd_max(a, b); - } -}; - -template -struct SumReduction { - T operator()(const T& a, const T& b) { - return a + b; - } - - VT operator()(VT a, VT b) { - return a + b; - } -}; - -template -struct StridedReduce { - void operator()(const T* x, T* accum, int size, size_t stride) { - Reduction op; - - for (int i = 0; i < size; i++) { - size_t s = stride; - T* a = accum; - while (s >= N) { - *(VT*)a = op((*(VT*)x), (*(VT*)a)); - x += N; - a += N; - s -= N; - } - while (s-- > 0) { - *a = op(*a, *x); - a++; - x++; - } - } - } -}; - -} // namespace - -void Reduce::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - - if (in.dtype() == float32) { - if (reduce_type_ == Reduce::Sum) { - reduction_op( - in, - out, - axes_, - 0, - StridedReduce< - float, - simd_float16, - 16, - SumReduction>(), - [](const auto* x, auto* accum, int size) { - float acc; - vDSP_sve((const float*)x, 1, &acc, size); - (*accum) += acc; - }, - [](auto* accum, auto x) { *accum += x; }); - return; - } else if (reduce_type_ == Reduce::Max) { - reduction_op( - in, - out, - axes_, - -std::numeric_limits::infinity(), - StridedReduce< - float, - simd_float16, - 16, - MaxReduction>(), - [](const auto* x, auto* accum, int size) { - float max; - vDSP_maxv((const float*)x, 1, &max, size); - (*accum) = (*accum < max) ? max : *accum; - }, - [](auto* accum, auto x) { (*accum) = (*accum < x) ? x : *accum; }); - return; - } else if (reduce_type_ == Reduce::Min) { - reduction_op( - in, - out, - axes_, - std::numeric_limits::infinity(), - StridedReduce< - float, - simd_float16, - 16, - MinReduction>(), - [](const auto* x, auto* accum, int size) { - float min; - vDSP_minv((const float*)x, 1, &min, size); - (*accum) = (*accum > min) ? min : *accum; - }, - [](auto* accum, auto x) { (*accum) = (*accum > x) ? x : *accum; }); - return; - } - } - // TODO: Add integer addition and min/max using the templates above and - // simd_int16 and friends. - eval(inputs, out); -} - -} // namespace mlx::core diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp deleted file mode 100644 index 21779c35a..000000000 --- a/mlx/backend/common/default_primitives.cpp +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#include "mlx/array.h" -#include "mlx/primitives.h" - -#define DEFAULT(primitive) \ - void primitive::eval_cpu(const std::vector& inputs, array& out) { \ - primitive::eval(inputs, out); \ - } - -#define DEFAULT_MULTI(primitive) \ - void primitive::eval_cpu( \ - const std::vector& inputs, std::vector& outputs) { \ - primitive::eval(inputs, outputs); \ - } - -namespace mlx::core { - -DEFAULT(Reduce) -DEFAULT(Scan) - -} // namespace mlx::core diff --git a/mlx/backend/common/reduce.cpp b/mlx/backend/common/reduce.cpp index 332ee7169..71c72e2ea 100644 --- a/mlx/backend/common/reduce.cpp +++ b/mlx/backend/common/reduce.cpp @@ -5,6 +5,7 @@ #include #include "mlx/backend/common/reduce.h" +#include "mlx/backend/common/simd/simd.h" #include "mlx/primitives.h" namespace mlx::core { @@ -67,55 +68,121 @@ const complex64_t Limits::min = struct AndReduce { template - void operator()(bool* a, T b) { - (*a) &= (b != 0); + bool operator()(bool x, T y) { + return x & (y != 0); } - void operator()(bool* y, bool x) { - (*y) &= x; + bool operator()(bool x, bool y) { + return x & y; } + + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return x & (y != 0); + }; + + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return x & y; + }; + + template + bool operator()(simd::Simd x) { + return simd::all(x); + }; }; struct OrReduce { template - void operator()(bool* a, T b) { - (*a) |= (b != 0); + bool operator()(bool x, T y) { + return x | (y != 0); } - void operator()(bool* y, bool x) { - (*y) |= x; + bool operator()(bool x, bool y) { + return x | y; } + + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return x | (y != 0); + }; + + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return x | y; + }; + + template + bool operator()(simd::Simd x) { + return simd::any(x); + }; }; struct MaxReduce { template - std::enable_if_t> operator()(T* y, T x) { - (*y) = (*y > x) ? *y : x; + T operator()(T y, T x) { + return (*this)(simd::Simd(x), simd::Simd(y)).value; }; - template - std::enable_if_t> operator()(T* y, T x) { - if (std::isnan(x)) { - *y = x; - } else { - (*y) = (*y > x) ? *y : x; - } + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return simd::maximum(x, y); + }; + + template + T operator()(simd::Simd x) { + return simd::max(x); }; }; struct MinReduce { template - std::enable_if_t> operator()(T* y, T x) { - (*y) = (*y < x) ? *y : x; + T operator()(T y, T x) { + return (*this)(simd::Simd(x), simd::Simd(y)).value; }; - template - std::enable_if_t> operator()(T* y, T x) { - if (std::isnan(x)) { - *y = x; - } else { - (*y) = (*y < x) ? *y : x; - } + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return simd::minimum(x, y); + }; + + template + T operator()(simd::Simd x) { + return simd::min(x); + }; +}; + +struct SumReduce { + template + U operator()(U y, T x) { + return x + y; + }; + + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return y + x; + }; + + template + T operator()(simd::Simd x) { + return simd::sum(x); + }; +}; + +struct ProdReduce { + template + U operator()(U y, T x) { + return x * y; + }; + + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return x * y; + }; + + template + T operator()(simd::Simd x) { + return simd::prod(x); }; }; @@ -139,18 +206,16 @@ void reduce_dispatch_sum_prod( Reduce::ReduceType rtype, const std::vector& axes) { if (rtype == Reduce::Sum) { - auto op = [](auto y, auto x) { (*y) = (*y) + x; }; if constexpr (std::is_integral_v && sizeof(InT) <= 4) { - reduction_op(in, out, axes, 0, op); + reduction_op(in, out, axes, 0, SumReduce()); } else { - reduction_op(in, out, axes, 0, op); + reduction_op(in, out, axes, 0, SumReduce()); } } else { - auto op = [](auto y, auto x) { (*y) *= x; }; if constexpr (std::is_integral_v && sizeof(InT) <= 4) { - reduction_op(in, out, axes, 1, op); + reduction_op(in, out, axes, 1, ProdReduce()); } else { - reduction_op(in, out, axes, 1, op); + reduction_op(in, out, axes, 1, ProdReduce()); } } } @@ -195,7 +260,7 @@ void nd_loop( loop_inner(0, 0); } -void Reduce::eval(const std::vector& inputs, array& out) { +void Reduce::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; switch (reduce_type_) { diff --git a/mlx/backend/common/reduce.h b/mlx/backend/common/reduce.h index 35d8f9e48..b9e44ddc8 100644 --- a/mlx/backend/common/reduce.h +++ b/mlx/backend/common/reduce.h @@ -2,6 +2,7 @@ #pragma once +#include "mlx/backend/common/simd/simd.h" #include "mlx/backend/common/utils.h" namespace mlx::core { @@ -60,45 +61,54 @@ std::pair shapes_without_reduction_axes( const std::vector& axes); template -struct DefaultStridedReduce { - Op op; - - DefaultStridedReduce(Op op_) : op(op_) {} - - void operator()(const T* x, U* accumulator, int size, size_t stride) { - for (int i = 0; i < size; i++) { - U* moving_accumulator = accumulator; - for (int j = 0; j < stride; j++) { - op(moving_accumulator, *x); - moving_accumulator++; - x++; - } +void strided_reduce( + const T* x, + U* accumulator, + int size, + size_t stride, + Op op) { + constexpr int N = std::min(simd::max_size, simd::max_size); + for (int i = 0; i < size; i++) { + U* moving_accumulator = accumulator; + auto s = stride; + while (s >= N) { + auto acc = simd::load(moving_accumulator); + auto v = simd::Simd(simd::load(x)); + simd::store(moving_accumulator, op(acc, v)); + moving_accumulator += N; + x += N; + s -= N; } - } -}; - -template -struct DefaultContiguousReduce { - Op op; - - DefaultContiguousReduce(Op op_) : op(op_) {} - - void operator()(const T* x, U* accumulator, int size) { - while (size-- > 0) { - op(accumulator, *x); + while (s-- > 0) { + *moving_accumulator = op(*moving_accumulator, *x); + moving_accumulator++; x++; } } }; -template +template +void contiguous_reduce(const T* x, U* accumulator, int size, Op op, U init) { + constexpr int N = std::min(simd::max_size, simd::max_size); + simd::Simd accumulator_v(init); + while (size >= N) { + accumulator_v = op(accumulator_v, simd::Simd(simd::load(x))); + x += N; + size -= N; + } + *accumulator = op(*accumulator, op(accumulator_v)); + while (size-- > 0) { + *accumulator = op(*accumulator, *x); + x++; + } +} + +template void reduction_op( const array& x, array& out, const std::vector& axes, U init, - OpS ops, - OpC opc, Op op) { out.set_data(allocator::malloc_or_wait(out.nbytes())); ReductionPlan plan = get_reduction_plan(x, axes); @@ -106,7 +116,7 @@ void reduction_op( if (plan.type == ContiguousAllReduce) { U* out_ptr = out.data(); *out_ptr = init; - opc(x.data(), out_ptr, x.size()); + contiguous_reduce(x.data(), out_ptr, x.size(), op, init); return; } @@ -116,7 +126,7 @@ void reduction_op( U* out_ptr = out.data(); for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) { *out_ptr = init; - opc(x_ptr, out_ptr, reduction_size); + contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init); } return; } @@ -134,7 +144,7 @@ void reduction_op( for (int i = 0; i < out.size(); i++, out_ptr++) { int offset = elem_to_loc(i, shape, strides); *out_ptr = init; - opc(x_ptr + offset, out_ptr, reduction_size); + contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init); } } else { for (int i = 0; i < out.size(); i++, out_ptr++) { @@ -142,7 +152,12 @@ void reduction_op( *out_ptr = init; nd_loop( [&](int extra_offset) { - opc(x_ptr + offset + extra_offset, out_ptr, reduction_size); + contiguous_reduce( + x_ptr + offset + extra_offset, + out_ptr, + reduction_size, + op, + init); }, plan.shape, plan.strides); @@ -160,7 +175,7 @@ void reduction_op( U* out_ptr = out.data(); for (int i = 0; i < out.size(); i += reduction_stride) { std::fill_n(out_ptr, reduction_stride, init); - ops(x_ptr, out_ptr, reduction_size, reduction_stride); + strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op); x_ptr += reduction_stride * reduction_size; out_ptr += reduction_stride; } @@ -180,7 +195,8 @@ void reduction_op( for (int i = 0; i < out.size(); i += reduction_stride) { int offset = elem_to_loc(i, shape, strides); std::fill_n(out_ptr, reduction_stride, init); - ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride); + strided_reduce( + x_ptr + offset, out_ptr, reduction_size, reduction_stride, op); out_ptr += reduction_stride; } } else { @@ -189,10 +205,12 @@ void reduction_op( std::fill_n(out_ptr, reduction_stride, init); nd_loop( [&](int extra_offset) { - ops(x_ptr + offset + extra_offset, + strided_reduce( + x_ptr + offset + extra_offset, out_ptr, reduction_size, - reduction_stride); + reduction_stride, + op); }, plan.shape, plan.strides); @@ -210,7 +228,9 @@ void reduction_op( int offset = elem_to_loc(i, shape, strides); U val = init; nd_loop( - [&](int extra_offset) { op(&val, *(x_ptr + offset + extra_offset)); }, + [&](int extra_offset) { + val = op(val, *(x_ptr + offset + extra_offset)); + }, plan.shape, plan.strides); *out_ptr = val; @@ -218,16 +238,4 @@ void reduction_op( } } -template -void reduction_op( - const array& x, - array& out, - const std::vector& axes, - U init, - Op op) { - DefaultStridedReduce ops(op); - DefaultContiguousReduce opc(op); - reduction_op(x, out, axes, init, ops, opc, op); -} - } // namespace mlx::core diff --git a/mlx/backend/common/scan.cpp b/mlx/backend/common/scan.cpp index 153375aef..2430f3172 100644 --- a/mlx/backend/common/scan.cpp +++ b/mlx/backend/common/scan.cpp @@ -3,6 +3,7 @@ #include #include "mlx/backend/common/copy.h" +#include "mlx/backend/common/simd/simd.h" #include "mlx/backend/common/utils.h" #include "mlx/primitives.h" @@ -11,184 +12,178 @@ namespace mlx::core { namespace { template -struct DefaultContiguousScan { - Op op; - U init; - - DefaultContiguousScan(Op op_, U init_) : op(op_), init(init_) {} - - void operator()( - const T* input, - U* output, - int count, - int stride, - bool reverse, - bool inclusive) { - if (!reverse) { - if (inclusive) { - for (int i = 0; i < count; i++) { - *output = *input; - for (int j = 1; j < stride; j++) { - input++; - output++; - op(output, output - 1, input); - } - output++; +void contiguous_scan( + const T* input, + U* output, + int count, + int stride, + bool reverse, + bool inclusive, + const Op& op, + U init) { + if (!reverse) { + if (inclusive) { + for (int i = 0; i < count; i++) { + *output = *input; + for (int j = 1; j < stride; j++) { input++; - } - } else { - for (int i = 0; i < count; i++) { - *output = init; - for (int j = 1; j < stride; j++) { - op(output + 1, output, input); - input++; - output++; - } output++; - input++; + *output = op(*(output - 1), *input); } + output++; + input++; } } else { - if (inclusive) { - for (int i = 0; i < count; i++) { - output += stride - 1; - input += stride - 1; - *output = *input; - for (int j = 1; j < stride; j++) { - input--; - output--; - op(output, output + 1, input); - } - output += stride; - input += stride; + for (int i = 0; i < count; i++) { + *output = init; + for (int j = 1; j < stride; j++) { + *(output + 1) = op(*output, *input); + input++; + output++; } - } else { - for (int i = 0; i < count; i++) { - output += stride - 1; - input += stride - 1; - *output = init; - for (int j = 1; j < stride; j++) { - op(output - 1, output, input); - input--; - output--; - } - output += stride; - input += stride; + output++; + input++; + } + } + } else { + if (inclusive) { + for (int i = 0; i < count; i++) { + output += stride - 1; + input += stride - 1; + *output = *input; + for (int j = 1; j < stride; j++) { + input--; + output--; + *output = op(*(output + 1), *input); } + output += stride; + input += stride; + } + } else { + for (int i = 0; i < count; i++) { + output += stride - 1; + input += stride - 1; + *output = init; + for (int j = 1; j < stride; j++) { + *(output - 1) = op(*output, *input); + input--; + output--; + } + output += stride; + input += stride; } } } }; template -struct DefaultStridedScan { - Op op; - U init; - - DefaultStridedScan(Op op_, U init_) : op(op_), init(init_) {} - - void operator()( - const T* input, - U* output, - int count, - int size, - int stride, - bool reverse, - bool inclusive) { - // TODO: Vectorize the following naive implementation - if (!reverse) { - if (inclusive) { - for (int i = 0; i < count; i++) { - std::copy(input, input + stride, output); - output += stride; - input += stride; - for (int j = 1; j < size; j++) { - for (int k = 0; k < stride; k++) { - op(output, output - stride, input); - output++; - input++; - } - } - } - } else { - for (int i = 0; i < count; i++) { - std::fill(output, output + stride, init); - output += stride; - input += stride; - for (int j = 1; j < size; j++) { - for (int k = 0; k < stride; k++) { - op(output, output - stride, input - stride); - output++; - input++; - } +void strided_scan( + const T* input, + U* output, + int count, + int size, + int stride, + bool reverse, + bool inclusive, + const Op& op, + U init) { + // TODO: Vectorize the following naive implementation + if (!reverse) { + if (inclusive) { + for (int i = 0; i < count; i++) { + std::copy(input, input + stride, output); + output += stride; + input += stride; + for (int j = 1; j < size; j++) { + for (int k = 0; k < stride; k++) { + *output = op(*(output - stride), *input); + output++; + input++; } } } } else { - if (inclusive) { - for (int i = 0; i < count; i++) { - output += (size - 1) * stride; - input += (size - 1) * stride; - std::copy(input, input + stride, output); - for (int j = 1; j < size; j++) { - for (int k = 0; k < stride; k++) { - output--; - input--; - op(output, output + stride, input); - } + for (int i = 0; i < count; i++) { + std::fill(output, output + stride, init); + output += stride; + input += stride; + for (int j = 1; j < size; j++) { + for (int k = 0; k < stride; k++) { + *output = op(*(output - stride), *(input - stride)); + output++; + input++; } - output += size * stride; - input += size * stride; } - } else { - for (int i = 0; i < count; i++) { - output += (size - 1) * stride; - input += (size - 1) * stride; - std::fill(output, output + stride, init); - for (int j = 1; j < size; j++) { - for (int k = 0; k < stride; k++) { - output--; - input--; - op(output, output + stride, input + stride); - } + } + } + } else { + if (inclusive) { + for (int i = 0; i < count; i++) { + output += (size - 1) * stride; + input += (size - 1) * stride; + std::copy(input, input + stride, output); + for (int j = 1; j < size; j++) { + for (int k = 0; k < stride; k++) { + output--; + input--; + *output = op(*(output + stride), *input); } - output += size * stride; - input += size * stride; } + output += size * stride; + input += size * stride; + } + } else { + for (int i = 0; i < count; i++) { + output += (size - 1) * stride; + input += (size - 1) * stride; + std::fill(output, output + stride, init); + for (int j = 1; j < size; j++) { + for (int k = 0; k < stride; k++) { + output--; + input--; + *output = op(*(output + stride), *(input + stride)); + } + } + output += size * stride; + input += size * stride; } } } }; -template +template void scan_op( - OpCS opcs, - OpSS opss, const array& input, array& output, int axis, bool reverse, - bool inclusive) { + bool inclusive, + const Op& op, + U init) { output.set_data(allocator::malloc_or_wait(output.nbytes())); if (input.flags().row_contiguous) { if (input.strides()[axis] == 1) { - opcs( + contiguous_scan( input.data(), output.data(), input.size() / input.shape(axis), input.shape(axis), reverse, - inclusive); + inclusive, + op, + init); } else { - opss( + strided_scan( input.data(), output.data(), input.size() / input.shape(axis) / input.strides()[axis], input.shape(axis), input.strides()[axis], reverse, - inclusive); + inclusive, + op, + init); } } else { throw std::runtime_error("Scan op supports only contiguous inputs"); @@ -205,39 +200,31 @@ void scan_dispatch( bool inclusive) { switch (rtype) { case Scan::Sum: { - auto op = [](U* o, const U* y, const T* x) { *o = *y + *x; }; + auto op = [](U y, T x) { return y + x; }; auto init = static_cast(0); - auto opcs = DefaultContiguousScan(op, init); - auto opss = DefaultStridedScan(op, init); - scan_op(opcs, opss, input, output, axis, reverse, inclusive); + scan_op(input, output, axis, reverse, inclusive, op, init); break; } case Scan::Prod: { - auto op = [](U* o, const U* y, const T* x) { *o = *y * (*x); }; + auto op = [](U y, T x) { return y * x; }; auto init = static_cast(1); - auto opcs = DefaultContiguousScan(op, init); - auto opss = DefaultStridedScan(op, init); - scan_op(opcs, opss, input, output, axis, reverse, inclusive); + scan_op(input, output, axis, reverse, inclusive, op, init); break; } case Scan::Min: { - auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; }; + auto op = [](U y, T x) { return x < y ? x : y; }; auto init = (issubdtype(input.dtype(), floating)) ? static_cast(std::numeric_limits::infinity()) : std::numeric_limits::max(); - auto opcs = DefaultContiguousScan(op, init); - auto opss = DefaultStridedScan(op, init); - scan_op(opcs, opss, input, output, axis, reverse, inclusive); + scan_op(input, output, axis, reverse, inclusive, op, init); break; } case Scan::Max: { - auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; }; + auto op = [](U y, T x) { return x < y ? y : x; }; auto init = (issubdtype(input.dtype(), floating)) ? static_cast(-std::numeric_limits::infinity()) : std::numeric_limits::min(); - auto opcs = DefaultContiguousScan(op, init); - auto opss = DefaultStridedScan(op, init); - scan_op(opcs, opss, input, output, axis, reverse, inclusive); + scan_op(input, output, axis, reverse, inclusive, op, init); break; } } @@ -245,7 +232,7 @@ void scan_dispatch( } // namespace -void Scan::eval(const std::vector& inputs, array& out) { +void Scan::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); // Ensure contiguity diff --git a/mlx/backend/common/simd/accelerate_simd.h b/mlx/backend/common/simd/accelerate_simd.h index 443a8f617..7edb06df5 100644 --- a/mlx/backend/common/simd/accelerate_simd.h +++ b/mlx/backend/common/simd/accelerate_simd.h @@ -267,6 +267,10 @@ Simd fma(Simd x, Simd y, U z) { // Reductions +template +bool all(Simd x) { + return asd::all(x.value); +} template bool any(Simd x) { return asd::any(x.value); @@ -284,6 +288,14 @@ T min(Simd x) { return asd::reduce_min(x.value); } +template +T prod(Simd x) { + auto ptr = (T*)&x; + auto lhs = load(ptr); + auto rhs = load(ptr + N / 2); + return prod(lhs * rhs); +} + } // namespace mlx::core::simd #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/mlx/backend/common/simd/base_simd.h b/mlx/backend/common/simd/base_simd.h index d7e4fdc3d..c1e867811 100644 --- a/mlx/backend/common/simd/base_simd.h +++ b/mlx/backend/common/simd/base_simd.h @@ -246,6 +246,7 @@ Simd fma(Simd x, Simd y, U z) { DEFAULT_REDUCTION(max, T) DEFAULT_REDUCTION(min, T) DEFAULT_REDUCTION(sum, T) +DEFAULT_REDUCTION(prod, T) DEFAULT_REDUCTION(any, bool) DEFAULT_REDUCTION(all, bool) diff --git a/mlx/backend/common/simd/neon_fp16_simd.h b/mlx/backend/common/simd/neon_fp16_simd.h index 923e27776..269ff1305 100644 --- a/mlx/backend/common/simd/neon_fp16_simd.h +++ b/mlx/backend/common/simd/neon_fp16_simd.h @@ -200,5 +200,13 @@ inline float16_t sum(Simd x) { y = vpadd_f16(y, y); return vget_lane_f16(y, 0); } +inline float16_t prod(Simd x) { + auto hx = vmul_f16(vget_low_f16(x.value), vget_high_f16(x.value)); + auto out = hx[0]; + hx[0] *= hx[1]; + hx[0] *= hx[2]; + hx[0] *= hx[3]; + return hx[0]; +} } // namespace mlx::core::simd diff --git a/mlx/primitives.h b/mlx/primitives.h index 0e2722186..b1f55f9ac 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1691,8 +1691,6 @@ class Reduce : public UnaryPrimitive { private: ReduceType reduce_type_; std::vector axes_; - - void eval(const std::vector& inputs, array& out); }; class Round : public UnaryPrimitive { @@ -1758,8 +1756,6 @@ class Scan : public UnaryPrimitive { int axis_; bool reverse_; bool inclusive_; - - void eval(const std::vector& inputs, array& out); }; class Scatter : public UnaryPrimitive {