mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-14 20:41:13 +08:00
Remove accelerate/ (#1816)
* remove accelerate * comments * neon reduction
This commit is contained in:
parent
f5cc1eea72
commit
80c863b972
@ -37,13 +37,6 @@ endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
if(MLX_BUILD_ACCELERATE)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
|
||||
elseif(MLX_BUILD_CPU)
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
||||
|
@ -1,2 +0,0 @@
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp)
|
@ -1,47 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (reduce_type_ == Scan::Sum && out.dtype() == float32 &&
|
||||
in.flags().row_contiguous && in.strides()[axis_] == 1 && !inclusive_) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
int stride = in.shape(axis_);
|
||||
int count = in.size() / stride;
|
||||
const float* input = in.data<float>();
|
||||
float* output = out.data<float>();
|
||||
float s = 1.0;
|
||||
if (!reverse_) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
vDSP_vrsum(input - 1, 1, &s, output, 1, stride);
|
||||
input += stride;
|
||||
output += stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < count; i++) {
|
||||
input += stride - 1;
|
||||
output += stride - 1;
|
||||
vDSP_vrsum(input + 1, -1, &s, output, -1, stride);
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -1,139 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include <simd/vector.h>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct MinReduction {
|
||||
T operator()(const T& a, const T& b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
VT operator()(VT a, VT b) {
|
||||
return simd_min(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct MaxReduction {
|
||||
T operator()(const T& a, const T& b) {
|
||||
return std::max(a, b);
|
||||
}
|
||||
|
||||
VT operator()(VT a, VT b) {
|
||||
return simd_max(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct SumReduction {
|
||||
T operator()(const T& a, const T& b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
VT operator()(VT a, VT b) {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename VT, int N, typename Reduction>
|
||||
struct StridedReduce {
|
||||
void operator()(const T* x, T* accum, int size, size_t stride) {
|
||||
Reduction op;
|
||||
|
||||
for (int i = 0; i < size; i++) {
|
||||
size_t s = stride;
|
||||
T* a = accum;
|
||||
while (s >= N) {
|
||||
*(VT*)a = op((*(VT*)x), (*(VT*)a));
|
||||
x += N;
|
||||
a += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
*a = op(*a, *x);
|
||||
a++;
|
||||
x++;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
if (in.dtype() == float32) {
|
||||
if (reduce_type_ == Reduce::Sum) {
|
||||
reduction_op<float, float>(
|
||||
in,
|
||||
out,
|
||||
axes_,
|
||||
0,
|
||||
StridedReduce<
|
||||
float,
|
||||
simd_float16,
|
||||
16,
|
||||
SumReduction<float, simd_float16>>(),
|
||||
[](const auto* x, auto* accum, int size) {
|
||||
float acc;
|
||||
vDSP_sve((const float*)x, 1, &acc, size);
|
||||
(*accum) += acc;
|
||||
},
|
||||
[](auto* accum, auto x) { *accum += x; });
|
||||
return;
|
||||
} else if (reduce_type_ == Reduce::Max) {
|
||||
reduction_op<float, float>(
|
||||
in,
|
||||
out,
|
||||
axes_,
|
||||
-std::numeric_limits<float>::infinity(),
|
||||
StridedReduce<
|
||||
float,
|
||||
simd_float16,
|
||||
16,
|
||||
MaxReduction<float, simd_float16>>(),
|
||||
[](const auto* x, auto* accum, int size) {
|
||||
float max;
|
||||
vDSP_maxv((const float*)x, 1, &max, size);
|
||||
(*accum) = (*accum < max) ? max : *accum;
|
||||
},
|
||||
[](auto* accum, auto x) { (*accum) = (*accum < x) ? x : *accum; });
|
||||
return;
|
||||
} else if (reduce_type_ == Reduce::Min) {
|
||||
reduction_op<float, float>(
|
||||
in,
|
||||
out,
|
||||
axes_,
|
||||
std::numeric_limits<float>::infinity(),
|
||||
StridedReduce<
|
||||
float,
|
||||
simd_float16,
|
||||
16,
|
||||
MinReduction<float, simd_float16>>(),
|
||||
[](const auto* x, auto* accum, int size) {
|
||||
float min;
|
||||
vDSP_minv((const float*)x, 1, &min, size);
|
||||
(*accum) = (*accum > min) ? min : *accum;
|
||||
},
|
||||
[](auto* accum, auto x) { (*accum) = (*accum > x) ? x : *accum; });
|
||||
return;
|
||||
}
|
||||
}
|
||||
// TODO: Add integer addition and min/max using the templates above and
|
||||
// simd_int16 and friends.
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -1,22 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#define DEFAULT(primitive) \
|
||||
void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
|
||||
primitive::eval(inputs, out); \
|
||||
}
|
||||
|
||||
#define DEFAULT_MULTI(primitive) \
|
||||
void primitive::eval_cpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
primitive::eval(inputs, outputs); \
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
DEFAULT(Reduce)
|
||||
DEFAULT(Scan)
|
||||
|
||||
} // namespace mlx::core
|
@ -5,6 +5,7 @@
|
||||
#include <limits>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@ -67,55 +68,121 @@ const complex64_t Limits<complex64_t>::min =
|
||||
|
||||
struct AndReduce {
|
||||
template <typename T>
|
||||
void operator()(bool* a, T b) {
|
||||
(*a) &= (b != 0);
|
||||
bool operator()(bool x, T y) {
|
||||
return x & (y != 0);
|
||||
}
|
||||
|
||||
void operator()(bool* y, bool x) {
|
||||
(*y) &= x;
|
||||
bool operator()(bool x, bool y) {
|
||||
return x & y;
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
|
||||
return x & (y != 0);
|
||||
};
|
||||
|
||||
template <int N>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
|
||||
return x & y;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
bool operator()(simd::Simd<T, N> x) {
|
||||
return simd::all(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct OrReduce {
|
||||
template <typename T>
|
||||
void operator()(bool* a, T b) {
|
||||
(*a) |= (b != 0);
|
||||
bool operator()(bool x, T y) {
|
||||
return x | (y != 0);
|
||||
}
|
||||
|
||||
void operator()(bool* y, bool x) {
|
||||
(*y) |= x;
|
||||
bool operator()(bool x, bool y) {
|
||||
return x | y;
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
|
||||
return x | (y != 0);
|
||||
};
|
||||
|
||||
template <int N>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
|
||||
return x | y;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
bool operator()(simd::Simd<T, N> x) {
|
||||
return simd::any(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct MaxReduce {
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
(*y) = (*y > x) ? *y : x;
|
||||
T operator()(T y, T x) {
|
||||
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
if (std::isnan(x)) {
|
||||
*y = x;
|
||||
} else {
|
||||
(*y) = (*y > x) ? *y : x;
|
||||
}
|
||||
template <int N, typename T>
|
||||
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
|
||||
return simd::maximum(x, y);
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::max(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct MinReduce {
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
(*y) = (*y < x) ? *y : x;
|
||||
T operator()(T y, T x) {
|
||||
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
if (std::isnan(x)) {
|
||||
*y = x;
|
||||
} else {
|
||||
(*y) = (*y < x) ? *y : x;
|
||||
}
|
||||
template <int N, typename T>
|
||||
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
|
||||
return simd::minimum(x, y);
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::min(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct SumReduce {
|
||||
template <typename T, typename U>
|
||||
U operator()(U y, T x) {
|
||||
return x + y;
|
||||
};
|
||||
|
||||
template <int N, typename T, typename U>
|
||||
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
|
||||
return y + x;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::sum(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ProdReduce {
|
||||
template <typename T, typename U>
|
||||
U operator()(U y, T x) {
|
||||
return x * y;
|
||||
};
|
||||
|
||||
template <int N, typename T, typename U>
|
||||
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
|
||||
return x * y;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::prod(x);
|
||||
};
|
||||
};
|
||||
|
||||
@ -139,18 +206,16 @@ void reduce_dispatch_sum_prod(
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Sum) {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 0, op);
|
||||
reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
|
||||
}
|
||||
} else {
|
||||
auto op = [](auto y, auto x) { (*y) *= x; };
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t>(in, out, axes, 1, op);
|
||||
reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 1, op);
|
||||
reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -195,7 +260,7 @@ void nd_loop(
|
||||
loop_inner(0, 0);
|
||||
}
|
||||
|
||||
void Reduce::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
switch (reduce_type_) {
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@ -60,45 +61,54 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const std::vector<int>& axes);
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultStridedReduce {
|
||||
Op op;
|
||||
|
||||
DefaultStridedReduce(Op op_) : op(op_) {}
|
||||
|
||||
void operator()(const T* x, U* accumulator, int size, size_t stride) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
U* moving_accumulator = accumulator;
|
||||
for (int j = 0; j < stride; j++) {
|
||||
op(moving_accumulator, *x);
|
||||
moving_accumulator++;
|
||||
x++;
|
||||
}
|
||||
void strided_reduce(
|
||||
const T* x,
|
||||
U* accumulator,
|
||||
int size,
|
||||
size_t stride,
|
||||
Op op) {
|
||||
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
|
||||
for (int i = 0; i < size; i++) {
|
||||
U* moving_accumulator = accumulator;
|
||||
auto s = stride;
|
||||
while (s >= N) {
|
||||
auto acc = simd::load<U, N>(moving_accumulator);
|
||||
auto v = simd::Simd<U, N>(simd::load<T, N>(x));
|
||||
simd::store<U, N>(moving_accumulator, op(acc, v));
|
||||
moving_accumulator += N;
|
||||
x += N;
|
||||
s -= N;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultContiguousReduce {
|
||||
Op op;
|
||||
|
||||
DefaultContiguousReduce(Op op_) : op(op_) {}
|
||||
|
||||
void operator()(const T* x, U* accumulator, int size) {
|
||||
while (size-- > 0) {
|
||||
op(accumulator, *x);
|
||||
while (s-- > 0) {
|
||||
*moving_accumulator = op(*moving_accumulator, *x);
|
||||
moving_accumulator++;
|
||||
x++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename OpS, typename OpC, typename Op>
|
||||
template <typename T, typename U, typename Op>
|
||||
void contiguous_reduce(const T* x, U* accumulator, int size, Op op, U init) {
|
||||
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
|
||||
simd::Simd<U, N> accumulator_v(init);
|
||||
while (size >= N) {
|
||||
accumulator_v = op(accumulator_v, simd::Simd<U, N>(simd::load<T, N>(x)));
|
||||
x += N;
|
||||
size -= N;
|
||||
}
|
||||
*accumulator = op(*accumulator, op(accumulator_v));
|
||||
while (size-- > 0) {
|
||||
*accumulator = op(*accumulator, *x);
|
||||
x++;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void reduction_op(
|
||||
const array& x,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
U init,
|
||||
OpS ops,
|
||||
OpC opc,
|
||||
Op op) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
ReductionPlan plan = get_reduction_plan(x, axes);
|
||||
@ -106,7 +116,7 @@ void reduction_op(
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
U* out_ptr = out.data<U>();
|
||||
*out_ptr = init;
|
||||
opc(x.data<T>(), out_ptr, x.size());
|
||||
contiguous_reduce(x.data<T>(), out_ptr, x.size(), op, init);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -116,7 +126,7 @@ void reduction_op(
|
||||
U* out_ptr = out.data<U>();
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
|
||||
*out_ptr = init;
|
||||
opc(x_ptr, out_ptr, reduction_size);
|
||||
contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@ -134,7 +144,7 @@ void reduction_op(
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
opc(x_ptr + offset, out_ptr, reduction_size);
|
||||
contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
@ -142,7 +152,12 @@ void reduction_op(
|
||||
*out_ptr = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
opc(x_ptr + offset + extra_offset, out_ptr, reduction_size);
|
||||
contiguous_reduce(
|
||||
x_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
op,
|
||||
init);
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
@ -160,7 +175,7 @@ void reduction_op(
|
||||
U* out_ptr = out.data<U>();
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
ops(x_ptr, out_ptr, reduction_size, reduction_stride);
|
||||
strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op);
|
||||
x_ptr += reduction_stride * reduction_size;
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
@ -180,7 +195,8 @@ void reduction_op(
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride);
|
||||
strided_reduce(
|
||||
x_ptr + offset, out_ptr, reduction_size, reduction_stride, op);
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
} else {
|
||||
@ -189,10 +205,12 @@ void reduction_op(
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
ops(x_ptr + offset + extra_offset,
|
||||
strided_reduce(
|
||||
x_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
reduction_stride);
|
||||
reduction_stride,
|
||||
op);
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
@ -210,7 +228,9 @@ void reduction_op(
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
U val = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) { op(&val, *(x_ptr + offset + extra_offset)); },
|
||||
[&](int extra_offset) {
|
||||
val = op(val, *(x_ptr + offset + extra_offset));
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
*out_ptr = val;
|
||||
@ -218,16 +238,4 @@ void reduction_op(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void reduction_op(
|
||||
const array& x,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
U init,
|
||||
Op op) {
|
||||
DefaultStridedReduce<T, U, Op> ops(op);
|
||||
DefaultContiguousReduce<T, U, Op> opc(op);
|
||||
reduction_op<T, U>(x, out, axes, init, ops, opc, op);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@ -11,184 +12,178 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultContiguousScan {
|
||||
Op op;
|
||||
U init;
|
||||
|
||||
DefaultContiguousScan(Op op_, U init_) : op(op_), init(init_) {}
|
||||
|
||||
void operator()(
|
||||
const T* input,
|
||||
U* output,
|
||||
int count,
|
||||
int stride,
|
||||
bool reverse,
|
||||
bool inclusive) {
|
||||
if (!reverse) {
|
||||
if (inclusive) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
*output = *input;
|
||||
for (int j = 1; j < stride; j++) {
|
||||
input++;
|
||||
output++;
|
||||
op(output, output - 1, input);
|
||||
}
|
||||
output++;
|
||||
void contiguous_scan(
|
||||
const T* input,
|
||||
U* output,
|
||||
int count,
|
||||
int stride,
|
||||
bool reverse,
|
||||
bool inclusive,
|
||||
const Op& op,
|
||||
U init) {
|
||||
if (!reverse) {
|
||||
if (inclusive) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
*output = *input;
|
||||
for (int j = 1; j < stride; j++) {
|
||||
input++;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < count; i++) {
|
||||
*output = init;
|
||||
for (int j = 1; j < stride; j++) {
|
||||
op(output + 1, output, input);
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
output++;
|
||||
input++;
|
||||
*output = op(*(output - 1), *input);
|
||||
}
|
||||
output++;
|
||||
input++;
|
||||
}
|
||||
} else {
|
||||
if (inclusive) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
output += stride - 1;
|
||||
input += stride - 1;
|
||||
*output = *input;
|
||||
for (int j = 1; j < stride; j++) {
|
||||
input--;
|
||||
output--;
|
||||
op(output, output + 1, input);
|
||||
}
|
||||
output += stride;
|
||||
input += stride;
|
||||
for (int i = 0; i < count; i++) {
|
||||
*output = init;
|
||||
for (int j = 1; j < stride; j++) {
|
||||
*(output + 1) = op(*output, *input);
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < count; i++) {
|
||||
output += stride - 1;
|
||||
input += stride - 1;
|
||||
*output = init;
|
||||
for (int j = 1; j < stride; j++) {
|
||||
op(output - 1, output, input);
|
||||
input--;
|
||||
output--;
|
||||
}
|
||||
output += stride;
|
||||
input += stride;
|
||||
output++;
|
||||
input++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (inclusive) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
output += stride - 1;
|
||||
input += stride - 1;
|
||||
*output = *input;
|
||||
for (int j = 1; j < stride; j++) {
|
||||
input--;
|
||||
output--;
|
||||
*output = op(*(output + 1), *input);
|
||||
}
|
||||
output += stride;
|
||||
input += stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < count; i++) {
|
||||
output += stride - 1;
|
||||
input += stride - 1;
|
||||
*output = init;
|
||||
for (int j = 1; j < stride; j++) {
|
||||
*(output - 1) = op(*output, *input);
|
||||
input--;
|
||||
output--;
|
||||
}
|
||||
output += stride;
|
||||
input += stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultStridedScan {
|
||||
Op op;
|
||||
U init;
|
||||
|
||||
DefaultStridedScan(Op op_, U init_) : op(op_), init(init_) {}
|
||||
|
||||
void operator()(
|
||||
const T* input,
|
||||
U* output,
|
||||
int count,
|
||||
int size,
|
||||
int stride,
|
||||
bool reverse,
|
||||
bool inclusive) {
|
||||
// TODO: Vectorize the following naive implementation
|
||||
if (!reverse) {
|
||||
if (inclusive) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
std::copy(input, input + stride, output);
|
||||
output += stride;
|
||||
input += stride;
|
||||
for (int j = 1; j < size; j++) {
|
||||
for (int k = 0; k < stride; k++) {
|
||||
op(output, output - stride, input);
|
||||
output++;
|
||||
input++;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < count; i++) {
|
||||
std::fill(output, output + stride, init);
|
||||
output += stride;
|
||||
input += stride;
|
||||
for (int j = 1; j < size; j++) {
|
||||
for (int k = 0; k < stride; k++) {
|
||||
op(output, output - stride, input - stride);
|
||||
output++;
|
||||
input++;
|
||||
}
|
||||
void strided_scan(
|
||||
const T* input,
|
||||
U* output,
|
||||
int count,
|
||||
int size,
|
||||
int stride,
|
||||
bool reverse,
|
||||
bool inclusive,
|
||||
const Op& op,
|
||||
U init) {
|
||||
// TODO: Vectorize the following naive implementation
|
||||
if (!reverse) {
|
||||
if (inclusive) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
std::copy(input, input + stride, output);
|
||||
output += stride;
|
||||
input += stride;
|
||||
for (int j = 1; j < size; j++) {
|
||||
for (int k = 0; k < stride; k++) {
|
||||
*output = op(*(output - stride), *input);
|
||||
output++;
|
||||
input++;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (inclusive) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
output += (size - 1) * stride;
|
||||
input += (size - 1) * stride;
|
||||
std::copy(input, input + stride, output);
|
||||
for (int j = 1; j < size; j++) {
|
||||
for (int k = 0; k < stride; k++) {
|
||||
output--;
|
||||
input--;
|
||||
op(output, output + stride, input);
|
||||
}
|
||||
for (int i = 0; i < count; i++) {
|
||||
std::fill(output, output + stride, init);
|
||||
output += stride;
|
||||
input += stride;
|
||||
for (int j = 1; j < size; j++) {
|
||||
for (int k = 0; k < stride; k++) {
|
||||
*output = op(*(output - stride), *(input - stride));
|
||||
output++;
|
||||
input++;
|
||||
}
|
||||
output += size * stride;
|
||||
input += size * stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < count; i++) {
|
||||
output += (size - 1) * stride;
|
||||
input += (size - 1) * stride;
|
||||
std::fill(output, output + stride, init);
|
||||
for (int j = 1; j < size; j++) {
|
||||
for (int k = 0; k < stride; k++) {
|
||||
output--;
|
||||
input--;
|
||||
op(output, output + stride, input + stride);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (inclusive) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
output += (size - 1) * stride;
|
||||
input += (size - 1) * stride;
|
||||
std::copy(input, input + stride, output);
|
||||
for (int j = 1; j < size; j++) {
|
||||
for (int k = 0; k < stride; k++) {
|
||||
output--;
|
||||
input--;
|
||||
*output = op(*(output + stride), *input);
|
||||
}
|
||||
output += size * stride;
|
||||
input += size * stride;
|
||||
}
|
||||
output += size * stride;
|
||||
input += size * stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < count; i++) {
|
||||
output += (size - 1) * stride;
|
||||
input += (size - 1) * stride;
|
||||
std::fill(output, output + stride, init);
|
||||
for (int j = 1; j < size; j++) {
|
||||
for (int k = 0; k < stride; k++) {
|
||||
output--;
|
||||
input--;
|
||||
*output = op(*(output + stride), *(input + stride));
|
||||
}
|
||||
}
|
||||
output += size * stride;
|
||||
input += size * stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename OpCS, typename OpSS>
|
||||
template <typename T, typename U, typename Op>
|
||||
void scan_op(
|
||||
OpCS opcs,
|
||||
OpSS opss,
|
||||
const array& input,
|
||||
array& output,
|
||||
int axis,
|
||||
bool reverse,
|
||||
bool inclusive) {
|
||||
bool inclusive,
|
||||
const Op& op,
|
||||
U init) {
|
||||
output.set_data(allocator::malloc_or_wait(output.nbytes()));
|
||||
|
||||
if (input.flags().row_contiguous) {
|
||||
if (input.strides()[axis] == 1) {
|
||||
opcs(
|
||||
contiguous_scan(
|
||||
input.data<T>(),
|
||||
output.data<U>(),
|
||||
input.size() / input.shape(axis),
|
||||
input.shape(axis),
|
||||
reverse,
|
||||
inclusive);
|
||||
inclusive,
|
||||
op,
|
||||
init);
|
||||
} else {
|
||||
opss(
|
||||
strided_scan(
|
||||
input.data<T>(),
|
||||
output.data<U>(),
|
||||
input.size() / input.shape(axis) / input.strides()[axis],
|
||||
input.shape(axis),
|
||||
input.strides()[axis],
|
||||
reverse,
|
||||
inclusive);
|
||||
inclusive,
|
||||
op,
|
||||
init);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Scan op supports only contiguous inputs");
|
||||
@ -205,39 +200,31 @@ void scan_dispatch(
|
||||
bool inclusive) {
|
||||
switch (rtype) {
|
||||
case Scan::Sum: {
|
||||
auto op = [](U* o, const U* y, const T* x) { *o = *y + *x; };
|
||||
auto op = [](U y, T x) { return y + x; };
|
||||
auto init = static_cast<U>(0);
|
||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
|
||||
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
|
||||
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
case Scan::Prod: {
|
||||
auto op = [](U* o, const U* y, const T* x) { *o = *y * (*x); };
|
||||
auto op = [](U y, T x) { return y * x; };
|
||||
auto init = static_cast<U>(1);
|
||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
|
||||
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
|
||||
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
case Scan::Min: {
|
||||
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
|
||||
auto op = [](U y, T x) { return x < y ? x : y; };
|
||||
auto init = (issubdtype(input.dtype(), floating))
|
||||
? static_cast<U>(std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::max();
|
||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
|
||||
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
|
||||
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
case Scan::Max: {
|
||||
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
|
||||
auto op = [](U y, T x) { return x < y ? y : x; };
|
||||
auto init = (issubdtype(input.dtype(), floating))
|
||||
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::min();
|
||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
|
||||
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
|
||||
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -245,7 +232,7 @@ void scan_dispatch(
|
||||
|
||||
} // namespace
|
||||
|
||||
void Scan::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
// Ensure contiguity
|
||||
|
@ -267,6 +267,10 @@ Simd<T, N> fma(Simd<T, N> x, Simd<T, N> y, U z) {
|
||||
|
||||
// Reductions
|
||||
|
||||
template <typename T, int N>
|
||||
bool all(Simd<T, N> x) {
|
||||
return asd::all(x.value);
|
||||
}
|
||||
template <typename T, int N>
|
||||
bool any(Simd<T, N> x) {
|
||||
return asd::any(x.value);
|
||||
@ -284,6 +288,14 @@ T min(Simd<T, N> x) {
|
||||
return asd::reduce_min(x.value);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
T prod(Simd<T, N> x) {
|
||||
auto ptr = (T*)&x;
|
||||
auto lhs = load<T, N / 2>(ptr);
|
||||
auto rhs = load<T, N / 2>(ptr + N / 2);
|
||||
return prod(lhs * rhs);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::simd
|
||||
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
|
@ -246,6 +246,7 @@ Simd<T, 1> fma(Simd<T, 1> x, Simd<T, 1> y, U z) {
|
||||
DEFAULT_REDUCTION(max, T)
|
||||
DEFAULT_REDUCTION(min, T)
|
||||
DEFAULT_REDUCTION(sum, T)
|
||||
DEFAULT_REDUCTION(prod, T)
|
||||
DEFAULT_REDUCTION(any, bool)
|
||||
DEFAULT_REDUCTION(all, bool)
|
||||
|
||||
|
@ -200,5 +200,13 @@ inline float16_t sum(Simd<float16_t, N> x) {
|
||||
y = vpadd_f16(y, y);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
inline float16_t prod(Simd<float16_t, N> x) {
|
||||
auto hx = vmul_f16(vget_low_f16(x.value), vget_high_f16(x.value));
|
||||
auto out = hx[0];
|
||||
hx[0] *= hx[1];
|
||||
hx[0] *= hx[2];
|
||||
hx[0] *= hx[3];
|
||||
return hx[0];
|
||||
}
|
||||
|
||||
} // namespace mlx::core::simd
|
||||
|
@ -1691,8 +1691,6 @@ class Reduce : public UnaryPrimitive {
|
||||
private:
|
||||
ReduceType reduce_type_;
|
||||
std::vector<int> axes_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Round : public UnaryPrimitive {
|
||||
@ -1758,8 +1756,6 @@ class Scan : public UnaryPrimitive {
|
||||
int axis_;
|
||||
bool reverse_;
|
||||
bool inclusive_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Scatter : public UnaryPrimitive {
|
||||
|
Loading…
Reference in New Issue
Block a user