Remove accelerate/ (#1816)

* remove accelerate

* comments

* neon reduction
Awni Hannun 2025-02-01 07:18:26 -08:00 committed by GitHub
parent f5cc1eea72
commit 80c863b972
12 changed files with 311 additions and 451 deletions

View File

@ -37,13 +37,6 @@ endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if(MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
elseif(MLX_BUILD_CPU)
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
endif()
if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)

View File

@ -1,2 +0,0 @@
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp)

View File

@ -1,47 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
#include <Accelerate/Accelerate.h>
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/unary.h"
#include "mlx/primitives.h"
namespace mlx::core {
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (reduce_type_ == Scan::Sum && out.dtype() == float32 &&
in.flags().row_contiguous && in.strides()[axis_] == 1 && !inclusive_) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
int stride = in.shape(axis_);
int count = in.size() / stride;
const float* input = in.data<float>();
float* output = out.data<float>();
float s = 1.0;
if (!reverse_) {
for (int i = 0; i < count; i++) {
vDSP_vrsum(input - 1, 1, &s, output, 1, stride);
input += stride;
output += stride;
}
} else {
for (int i = 0; i < count; i++) {
input += stride - 1;
output += stride - 1;
vDSP_vrsum(input + 1, -1, &s, output, -1, stride);
input++;
output++;
}
}
} else {
eval(inputs, out);
}
}
} // namespace mlx::core
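
For context: the fast path deleted above only fired for an exclusive float32 sum scan over a row-contiguous, stride-1 axis, where vDSP_vrsum computes a per-row running sum (scaled by s = 1.0). A minimal scalar sketch of that per-row computation, for illustration only (the helper name is not from the diff); the generic path in backend/common/scan.cpp now covers this case:

#include <cstdio>

// Plain-C++ equivalent of one vDSP_vrsum call here: an exclusive prefix sum.
void exclusive_prefix_sum(const float* input, float* output, int stride) {
  float acc = 0.0f;
  for (int j = 0; j < stride; ++j) {
    output[j] = acc; // sum of input[0..j-1]; output[0] is 0
    acc += input[j];
  }
}

int main() {
  const float in[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  float out[4];
  exclusive_prefix_sum(in, out, 4);
  for (float v : out) std::printf("%g ", v); // prints: 0 1 3 6
  std::printf("\n");
}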

View File

@ -1,139 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <Accelerate/Accelerate.h>
#include <simd/vector.h>
#include "mlx/backend/common/reduce.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T, typename VT>
struct MinReduction {
T operator()(const T& a, const T& b) {
return std::min(a, b);
}
VT operator()(VT a, VT b) {
return simd_min(a, b);
}
};
template <typename T, typename VT>
struct MaxReduction {
T operator()(const T& a, const T& b) {
return std::max(a, b);
}
VT operator()(VT a, VT b) {
return simd_max(a, b);
}
};
template <typename T, typename VT>
struct SumReduction {
T operator()(const T& a, const T& b) {
return a + b;
}
VT operator()(VT a, VT b) {
return a + b;
}
};
template <typename T, typename VT, int N, typename Reduction>
struct StridedReduce {
void operator()(const T* x, T* accum, int size, size_t stride) {
Reduction op;
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = op((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = op(*a, *x);
a++;
x++;
}
}
}
};
} // namespace
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32) {
if (reduce_type_ == Reduce::Sum) {
reduction_op<float, float>(
in,
out,
axes_,
0,
StridedReduce<
float,
simd_float16,
16,
SumReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) {
float acc;
vDSP_sve((const float*)x, 1, &acc, size);
(*accum) += acc;
},
[](auto* accum, auto x) { *accum += x; });
return;
} else if (reduce_type_ == Reduce::Max) {
reduction_op<float, float>(
in,
out,
axes_,
-std::numeric_limits<float>::infinity(),
StridedReduce<
float,
simd_float16,
16,
MaxReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) {
float max;
vDSP_maxv((const float*)x, 1, &max, size);
(*accum) = (*accum < max) ? max : *accum;
},
[](auto* accum, auto x) { (*accum) = (*accum < x) ? x : *accum; });
return;
} else if (reduce_type_ == Reduce::Min) {
reduction_op<float, float>(
in,
out,
axes_,
std::numeric_limits<float>::infinity(),
StridedReduce<
float,
simd_float16,
16,
MinReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) {
float min;
vDSP_minv((const float*)x, 1, &min, size);
(*accum) = (*accum > min) ? min : *accum;
},
[](auto* accum, auto x) { (*accum) = (*accum > x) ? x : *accum; });
return;
}
}
// TODO: Add integer addition and min/max using the templates above and
// simd_int16 and friends.
eval(inputs, out);
}
} // namespace mlx::core
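
The deleted StridedReduce walked the reduced axis in 16-float simd_float16 chunks and fell back to a scalar tail. A plain scalar sketch of the same access pattern, assuming x points at size * stride contiguous elements and accum is pre-filled with the reduction's identity (names here are illustrative, not from the diff):

#include <cstddef>
#include <cstdio>

template <typename T, typename Op>
void strided_reduce_sketch(const T* x, T* accum, int size, size_t stride, Op op) {
  // Reduce `size` rows of length `stride` element-wise into accum[0..stride).
  for (int i = 0; i < size; ++i)
    for (size_t j = 0; j < stride; ++j)
      accum[j] = op(accum[j], *x++); // the removed code did this 16 lanes at a time
}

int main() {
  const float x[6] = {1, 2, 3, 4, 5, 6}; // 3 rows, stride 2
  float acc[2] = {0, 0};
  strided_reduce_sketch(x, acc, 3, 2, [](float a, float b) { return a + b; });
  std::printf("%g %g\n", acc[0], acc[1]); // 9 12
}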

View File

@ -1,22 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/array.h"
#include "mlx/primitives.h"
#define DEFAULT(primitive) \
void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
primitive::eval(inputs, out); \
}
#define DEFAULT_MULTI(primitive) \
void primitive::eval_cpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
primitive::eval(inputs, outputs); \
}
namespace mlx::core {
DEFAULT(Reduce)
DEFAULT(Scan)
} // namespace mlx::core
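
For reference, each DEFAULT(...) line in this deleted file expanded to a forwarding stub; those stubs are gone because Reduce::eval_cpu and Scan::eval_cpu are now implemented directly in backend/common. Roughly what DEFAULT(Reduce) expanded to (shown for illustration only):

#include "mlx/array.h"
#include "mlx/primitives.h"

namespace mlx::core {
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
  Reduce::eval(inputs, out);
}
} // namespace mlx::core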

View File

@ -5,6 +5,7 @@
#include <limits>
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/common/simd/simd.h"
#include "mlx/primitives.h"
namespace mlx::core {
@ -67,55 +68,121 @@ const complex64_t Limits<complex64_t>::min =
struct AndReduce {
template <typename T>
void operator()(bool* a, T b) {
(*a) &= (b != 0);
bool operator()(bool x, T y) {
return x & (y != 0);
}
void operator()(bool* y, bool x) {
(*y) &= x;
bool operator()(bool x, bool y) {
return x & y;
}
template <int N, typename T>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
return x & (y != 0);
};
template <int N>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
return x & y;
};
template <int N, typename T>
bool operator()(simd::Simd<T, N> x) {
return simd::all(x);
};
};
struct OrReduce {
template <typename T>
void operator()(bool* a, T b) {
(*a) |= (b != 0);
bool operator()(bool x, T y) {
return x | (y != 0);
}
void operator()(bool* y, bool x) {
(*y) |= x;
bool operator()(bool x, bool y) {
return x | y;
}
template <int N, typename T>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
return x | (y != 0);
};
template <int N>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
return x | y;
};
template <int N, typename T>
bool operator()(simd::Simd<T, N> x) {
return simd::any(x);
};
};
struct MaxReduce {
template <typename T>
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
(*y) = (*y > x) ? *y : x;
T operator()(T y, T x) {
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
};
template <typename T>
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
if (std::isnan(x)) {
*y = x;
} else {
(*y) = (*y > x) ? *y : x;
}
template <int N, typename T>
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
return simd::maximum(x, y);
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::max(x);
};
};
struct MinReduce {
template <typename T>
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
(*y) = (*y < x) ? *y : x;
T operator()(T y, T x) {
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
};
template <typename T>
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
if (std::isnan(x)) {
*y = x;
} else {
(*y) = (*y < x) ? *y : x;
}
template <int N, typename T>
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
return simd::minimum(x, y);
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::min(x);
};
};
struct SumReduce {
template <typename T, typename U>
U operator()(U y, T x) {
return x + y;
};
template <int N, typename T, typename U>
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
return y + x;
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::sum(x);
};
};
struct ProdReduce {
template <typename T, typename U>
U operator()(U y, T x) {
return x * y;
};
template <int N, typename T, typename U>
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
return x * y;
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::prod(x);
};
};
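
Each of the reduce functors above supplies three kinds of overloads: combine two scalars, combine two SIMD vectors lane-wise, and fold one vector down to a scalar. A toy stand-in (not the mlx simd API) showing the shape of that contract:

#include <array>
#include <cstdio>

// Hypothetical stand-in for simd::Simd<T, N>, just to show the call shapes.
template <typename T, int N>
struct ToyVec {
  std::array<T, N> lanes;
};

struct ToySum {
  float operator()(float y, float x) { return y + x; } // scalar combine
  template <int N>
  ToyVec<float, N> operator()(ToyVec<float, N> y, ToyVec<float, N> x) {
    for (int i = 0; i < N; ++i) y.lanes[i] += x.lanes[i]; // lane-wise combine
    return y;
  }
  template <int N>
  float operator()(ToyVec<float, N> x) { // horizontal fold
    float acc = 0.0f;
    for (int i = 0; i < N; ++i) acc += x.lanes[i];
    return acc;
  }
};

int main() {
  ToySum op;
  ToyVec<float, 4> a{{1, 2, 3, 4}}, b{{10, 20, 30, 40}};
  float r = op(op(a, b));          // lane-wise combine, then horizontal fold
  std::printf("%g\n", op(r, 5.f)); // plus one scalar tail element: 115
}
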
@ -139,18 +206,16 @@ void reduce_dispatch_sum_prod(
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Sum) {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 0, op);
reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
} else {
reduction_op<InT, InT>(in, out, axes, 0, op);
reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
}
} else {
auto op = [](auto y, auto x) { (*y) *= x; };
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 1, op);
reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
} else {
reduction_op<InT, InT>(in, out, axes, 1, op);
reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
}
}
}
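
Note that reduce_dispatch_sum_prod above accumulates integer inputs of four bytes or fewer into int32_t instead of the input type. A quick standalone illustration of why (plain C++, not mlx code):

#include <cstdint>
#include <cstdio>

int main() {
  uint8_t narrow = 0; // accumulating in the input type wraps modulo 256
  int32_t wide = 0;   // a 32-bit accumulator stays exact here
  for (int i = 0; i < 200; ++i) {
    narrow = static_cast<uint8_t>(narrow + 255);
    wide += 255;
  }
  std::printf("uint8: %u  int32: %d\n", narrow, wide); // uint8: 56  int32: 51000
}
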
@ -195,7 +260,7 @@ void nd_loop(
loop_inner(0, 0);
}
void Reduce::eval(const std::vector<array>& inputs, array& out) {
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (reduce_type_) {

View File

@ -2,6 +2,7 @@
#pragma once
#include "mlx/backend/common/simd/simd.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
@ -60,45 +61,54 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
const std::vector<int>& axes);
template <typename T, typename U, typename Op>
struct DefaultStridedReduce {
Op op;
DefaultStridedReduce(Op op_) : op(op_) {}
void operator()(const T* x, U* accumulator, int size, size_t stride) {
void strided_reduce(
const T* x,
U* accumulator,
int size,
size_t stride,
Op op) {
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
for (int i = 0; i < size; i++) {
U* moving_accumulator = accumulator;
for (int j = 0; j < stride; j++) {
op(moving_accumulator, *x);
auto s = stride;
while (s >= N) {
auto acc = simd::load<U, N>(moving_accumulator);
auto v = simd::Simd<U, N>(simd::load<T, N>(x));
simd::store<U, N>(moving_accumulator, op(acc, v));
moving_accumulator += N;
x += N;
s -= N;
}
while (s-- > 0) {
*moving_accumulator = op(*moving_accumulator, *x);
moving_accumulator++;
x++;
}
}
}
};
template <typename T, typename U, typename Op>
struct DefaultContiguousReduce {
Op op;
DefaultContiguousReduce(Op op_) : op(op_) {}
void operator()(const T* x, U* accumulator, int size) {
void contiguous_reduce(const T* x, U* accumulator, int size, Op op, U init) {
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
simd::Simd<U, N> accumulator_v(init);
while (size >= N) {
accumulator_v = op(accumulator_v, simd::Simd<U, N>(simd::load<T, N>(x)));
x += N;
size -= N;
}
*accumulator = op(*accumulator, op(accumulator_v));
while (size-- > 0) {
op(accumulator, *x);
*accumulator = op(*accumulator, *x);
x++;
}
}
};
template <typename T, typename U, typename OpS, typename OpC, typename Op>
template <typename T, typename U, typename Op>
void reduction_op(
const array& x,
array& out,
const std::vector<int>& axes,
U init,
OpS ops,
OpC opc,
Op op) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
ReductionPlan plan = get_reduction_plan(x, axes);
@ -106,7 +116,7 @@ void reduction_op(
if (plan.type == ContiguousAllReduce) {
U* out_ptr = out.data<U>();
*out_ptr = init;
opc(x.data<T>(), out_ptr, x.size());
contiguous_reduce(x.data<T>(), out_ptr, x.size(), op, init);
return;
}
@ -116,7 +126,7 @@ void reduction_op(
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
*out_ptr = init;
opc(x_ptr, out_ptr, reduction_size);
contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init);
}
return;
}
@ -134,7 +144,7 @@ void reduction_op(
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
opc(x_ptr + offset, out_ptr, reduction_size);
contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init);
}
} else {
for (int i = 0; i < out.size(); i++, out_ptr++) {
@ -142,7 +152,12 @@ void reduction_op(
*out_ptr = init;
nd_loop(
[&](int extra_offset) {
opc(x_ptr + offset + extra_offset, out_ptr, reduction_size);
contiguous_reduce(
x_ptr + offset + extra_offset,
out_ptr,
reduction_size,
op,
init);
},
plan.shape,
plan.strides);
@ -160,7 +175,7 @@ void reduction_op(
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i += reduction_stride) {
std::fill_n(out_ptr, reduction_stride, init);
ops(x_ptr, out_ptr, reduction_size, reduction_stride);
strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op);
x_ptr += reduction_stride * reduction_size;
out_ptr += reduction_stride;
}
@ -180,7 +195,8 @@ void reduction_op(
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride);
strided_reduce(
x_ptr + offset, out_ptr, reduction_size, reduction_stride, op);
out_ptr += reduction_stride;
}
} else {
@ -189,10 +205,12 @@ void reduction_op(
std::fill_n(out_ptr, reduction_stride, init);
nd_loop(
[&](int extra_offset) {
ops(x_ptr + offset + extra_offset,
strided_reduce(
x_ptr + offset + extra_offset,
out_ptr,
reduction_size,
reduction_stride);
reduction_stride,
op);
},
plan.shape,
plan.strides);
@ -210,7 +228,9 @@ void reduction_op(
int offset = elem_to_loc(i, shape, strides);
U val = init;
nd_loop(
[&](int extra_offset) { op(&val, *(x_ptr + offset + extra_offset)); },
[&](int extra_offset) {
val = op(val, *(x_ptr + offset + extra_offset));
},
plan.shape,
plan.strides);
*out_ptr = val;
@ -218,16 +238,4 @@ void reduction_op(
}
}
template <typename T, typename U, typename Op>
void reduction_op(
const array& x,
array& out,
const std::vector<int>& axes,
U init,
Op op) {
DefaultStridedReduce<T, U, Op> ops(op);
DefaultContiguousReduce<T, U, Op> opc(op);
reduction_op<T, U>(x, out, axes, init, ops, opc, op);
}
} // namespace mlx::core
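
The new contiguous_reduce above keeps an N-lane SIMD accumulator seeded with init (the reduction's identity), runs a vectorized body, folds the lanes into the scalar accumulator, then finishes leftover elements in a scalar tail. A scalar mock of that structure, with a hypothetical lane width of 4 standing in for simd::max_size<float>:

#include <algorithm>
#include <cstdio>

template <typename Op>
float contiguous_reduce_mock(const float* x, int size, Op op, float init) {
  constexpr int N = 4; // stand-in for simd::max_size<float>
  float lanes[N];
  std::fill(lanes, lanes + N, init); // one partial accumulator per lane
  while (size >= N) {                // vectorized body
    for (int i = 0; i < N; ++i) lanes[i] = op(lanes[i], x[i]);
    x += N;
    size -= N;
  }
  float acc = init;
  for (int i = 0; i < N; ++i) acc = op(acc, lanes[i]); // horizontal fold
  while (size-- > 0) acc = op(acc, *x++);              // scalar tail
  return acc;
}

int main() {
  const float data[10] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  auto sum = [](float y, float x) { return y + x; };
  std::printf("%g\n", contiguous_reduce_mock(data, 10, sum, 0.0f)); // 55
}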

View File

@ -3,6 +3,7 @@
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/simd/simd.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
@ -11,19 +12,15 @@ namespace mlx::core {
namespace {
template <typename T, typename U, typename Op>
struct DefaultContiguousScan {
Op op;
U init;
DefaultContiguousScan(Op op_, U init_) : op(op_), init(init_) {}
void operator()(
void contiguous_scan(
const T* input,
U* output,
int count,
int stride,
bool reverse,
bool inclusive) {
bool inclusive,
const Op& op,
U init) {
if (!reverse) {
if (inclusive) {
for (int i = 0; i < count; i++) {
@ -31,7 +28,7 @@ struct DefaultContiguousScan {
for (int j = 1; j < stride; j++) {
input++;
output++;
op(output, output - 1, input);
*output = op(*(output - 1), *input);
}
output++;
input++;
@ -40,7 +37,7 @@ struct DefaultContiguousScan {
for (int i = 0; i < count; i++) {
*output = init;
for (int j = 1; j < stride; j++) {
op(output + 1, output, input);
*(output + 1) = op(*output, *input);
input++;
output++;
}
@ -57,7 +54,7 @@ struct DefaultContiguousScan {
for (int j = 1; j < stride; j++) {
input--;
output--;
op(output, output + 1, input);
*output = op(*(output + 1), *input);
}
output += stride;
input += stride;
@ -68,7 +65,7 @@ struct DefaultContiguousScan {
input += stride - 1;
*output = init;
for (int j = 1; j < stride; j++) {
op(output - 1, output, input);
*(output - 1) = op(*output, *input);
input--;
output--;
}
@ -77,24 +74,19 @@ struct DefaultContiguousScan {
}
}
}
}
};
template <typename T, typename U, typename Op>
struct DefaultStridedScan {
Op op;
U init;
DefaultStridedScan(Op op_, U init_) : op(op_), init(init_) {}
void operator()(
void strided_scan(
const T* input,
U* output,
int count,
int size,
int stride,
bool reverse,
bool inclusive) {
bool inclusive,
const Op& op,
U init) {
// TODO: Vectorize the following naive implementation
if (!reverse) {
if (inclusive) {
@ -104,7 +96,7 @@ struct DefaultStridedScan {
input += stride;
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
op(output, output - stride, input);
*output = op(*(output - stride), *input);
output++;
input++;
}
@ -117,7 +109,7 @@ struct DefaultStridedScan {
input += stride;
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
op(output, output - stride, input - stride);
*output = op(*(output - stride), *(input - stride));
output++;
input++;
}
@ -134,7 +126,7 @@ struct DefaultStridedScan {
for (int k = 0; k < stride; k++) {
output--;
input--;
op(output, output + stride, input);
*output = op(*(output + stride), *input);
}
}
output += size * stride;
@ -149,7 +141,7 @@ struct DefaultStridedScan {
for (int k = 0; k < stride; k++) {
output--;
input--;
op(output, output + stride, input + stride);
*output = op(*(output + stride), *(input + stride));
}
}
output += size * stride;
@ -157,38 +149,41 @@ struct DefaultStridedScan {
}
}
}
}
};
template <typename T, typename U, typename OpCS, typename OpSS>
template <typename T, typename U, typename Op>
void scan_op(
OpCS opcs,
OpSS opss,
const array& input,
array& output,
int axis,
bool reverse,
bool inclusive) {
bool inclusive,
const Op& op,
U init) {
output.set_data(allocator::malloc_or_wait(output.nbytes()));
if (input.flags().row_contiguous) {
if (input.strides()[axis] == 1) {
opcs(
contiguous_scan(
input.data<T>(),
output.data<U>(),
input.size() / input.shape(axis),
input.shape(axis),
reverse,
inclusive);
inclusive,
op,
init);
} else {
opss(
strided_scan(
input.data<T>(),
output.data<U>(),
input.size() / input.shape(axis) / input.strides()[axis],
input.shape(axis),
input.strides()[axis],
reverse,
inclusive);
inclusive,
op,
init);
}
} else {
throw std::runtime_error("Scan op supports only contiguous inputs");
@ -205,39 +200,31 @@ void scan_dispatch(
bool inclusive) {
switch (rtype) {
case Scan::Sum: {
auto op = [](U* o, const U* y, const T* x) { *o = *y + *x; };
auto op = [](U y, T x) { return y + x; };
auto init = static_cast<U>(0);
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
break;
}
case Scan::Prod: {
auto op = [](U* o, const U* y, const T* x) { *o = *y * (*x); };
auto op = [](U y, T x) { return y * x; };
auto init = static_cast<U>(1);
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
break;
}
case Scan::Min: {
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
auto op = [](U y, T x) { return x < y ? x : y; };
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
break;
}
case Scan::Max: {
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
auto op = [](U y, T x) { return x < y ? y : x; };
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
break;
}
}
@ -245,7 +232,7 @@ void scan_dispatch(
} // namespace
void Scan::eval(const std::vector<array>& inputs, array& out) {
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Ensure contiguity

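For reference, the inclusive/exclusive conventions the scans above implement, in minimal scalar form (plain C++, not the mlx kernels): an inclusive scan includes element j in output[j], while an exclusive scan starts from init and lags the input by one.

#include <cstdio>

int main() {
  const float in[5] = {1, 2, 3, 4, 5};
  float inc[5], exc[5];
  auto op = [](float y, float x) { return y + x; }; // Scan::Sum
  const float init = 0.0f;                          // identity for Sum

  inc[0] = in[0];
  exc[0] = init;
  for (int j = 1; j < 5; ++j) {
    inc[j] = op(inc[j - 1], in[j]);     // inclusive: contains in[j]
    exc[j] = op(exc[j - 1], in[j - 1]); // exclusive: stops before in[j]
  }
  for (int j = 0; j < 5; ++j) std::printf("%g/%g ", inc[j], exc[j]);
  std::printf("\n"); // inclusive: 1 3 6 10 15, exclusive: 0 1 3 6 10
}
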
View File

@ -267,6 +267,10 @@ Simd<T, N> fma(Simd<T, N> x, Simd<T, N> y, U z) {
// Reductions
template <typename T, int N>
bool all(Simd<T, N> x) {
return asd::all(x.value);
}
template <typename T, int N>
bool any(Simd<T, N> x) {
return asd::any(x.value);
@ -284,6 +288,14 @@ T min(Simd<T, N> x) {
return asd::reduce_min(x.value);
}
template <typename T, int N>
T prod(Simd<T, N> x) {
auto ptr = (T*)&x;
auto lhs = load<T, N / 2>(ptr);
auto rhs = load<T, N / 2>(ptr + N / 2);
return prod(lhs * rhs);
}
} // namespace mlx::core::simd
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

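The prod reduction above works by halving: split the N lanes in two, multiply the halves lane-wise, and recurse on the half-width vector until a scalar overload applies. A scalar sketch of that strategy, assuming a power-of-two lane count:

#include <cstdio>

// Halving product over n lanes (n a power of two), mirroring the recursion
// in prod() above but with plain floats instead of simd vectors.
float prod_halving(const float* x, int n) {
  if (n == 1) return x[0];
  float half[8]; // enough scratch for a 16-lane input after one split
  for (int i = 0; i < n / 2; ++i) half[i] = x[i] * x[i + n / 2];
  return prod_halving(half, n / 2);
}

int main() {
  const float lanes[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  std::printf("%g\n", prod_halving(lanes, 8)); // 8! = 40320
}
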
View File

@ -246,6 +246,7 @@ Simd<T, 1> fma(Simd<T, 1> x, Simd<T, 1> y, U z) {
DEFAULT_REDUCTION(max, T)
DEFAULT_REDUCTION(min, T)
DEFAULT_REDUCTION(sum, T)
DEFAULT_REDUCTION(prod, T)
DEFAULT_REDUCTION(any, bool)
DEFAULT_REDUCTION(all, bool)

View File

@ -200,5 +200,13 @@ inline float16_t sum(Simd<float16_t, N> x) {
y = vpadd_f16(y, y);
return vget_lane_f16(y, 0);
}
inline float16_t prod(Simd<float16_t, N> x) {
auto hx = vmul_f16(vget_low_f16(x.value), vget_high_f16(x.value));
auto out = hx[0];
hx[0] *= hx[1];
hx[0] *= hx[2];
hx[0] *= hx[3];
return hx[0];
}
} // namespace mlx::core::simd

View File

@ -1691,8 +1691,6 @@ class Reduce : public UnaryPrimitive {
private:
ReduceType reduce_type_;
std::vector<int> axes_;
void eval(const std::vector<array>& inputs, array& out);
};
class Round : public UnaryPrimitive {
@ -1758,8 +1756,6 @@ class Scan : public UnaryPrimitive {
int axis_;
bool reverse_;
bool inclusive_;
void eval(const std::vector<array>& inputs, array& out);
};
class Scatter : public UnaryPrimitive {