Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-14 20:41:13 +08:00)

Remove accelerate/ (#1816)

* remove accelerate
* comments
* neon reduction

This commit is contained in:
parent f5cc1eea72
commit 80c863b972
@@ -37,13 +37,6 @@ endif()
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
-if(MLX_BUILD_ACCELERATE)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
-elseif(MLX_BUILD_CPU)
-  target_sources(
-    mlx
-    PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
-endif()
 
 if(MLX_BUILD_METAL)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
@@ -1,2 +0,0 @@
-target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
-                           ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp)

@@ -1,47 +0,0 @@
-// Copyright © 2023-2024 Apple Inc.
-
-#include <cassert>
-#include <cmath>
-
-#include <Accelerate/Accelerate.h>
-
-#include "mlx/allocator.h"
-#include "mlx/backend/common/binary.h"
-#include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/unary.h"
-#include "mlx/primitives.h"
-
-namespace mlx::core {
-
-void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (reduce_type_ == Scan::Sum && out.dtype() == float32 &&
-      in.flags().row_contiguous && in.strides()[axis_] == 1 && !inclusive_) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    int stride = in.shape(axis_);
-    int count = in.size() / stride;
-    const float* input = in.data<float>();
-    float* output = out.data<float>();
-    float s = 1.0;
-    if (!reverse_) {
-      for (int i = 0; i < count; i++) {
-        vDSP_vrsum(input - 1, 1, &s, output, 1, stride);
-        input += stride;
-        output += stride;
-      }
-    } else {
-      for (int i = 0; i < count; i++) {
-        input += stride - 1;
-        output += stride - 1;
-        vDSP_vrsum(input + 1, -1, &s, output, -1, stride);
-        input++;
-        output++;
-      }
-    }
-  } else {
-    eval(inputs, out);
-  }
-}
-
-} // namespace mlx::core
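The deleted Accelerate scan path above only covered the float32, row-contiguous, exclusive-sum case and fell back to the generic eval for everything else. If I read vDSP_vrsum's running-sum semantics correctly, calling it with `input - 1` and a weight of 1 writes, at position j, the sum of the first j inputs, which is an exclusive prefix sum per row. A minimal plain-C++ sketch of that behaviour, with no Accelerate dependency (toy_exclusive_cumsum is an illustrative name, not MLX or vDSP API):

#include <cstdio>

// Exclusive prefix sum over one contiguous row, mirroring what the removed
// vDSP_vrsum call produced for the !inclusive_, !reverse_ case.
void toy_exclusive_cumsum(const float* input, float* output, int stride) {
  float acc = 0.0f;
  for (int j = 0; j < stride; ++j) {
    output[j] = acc;  // element j sees the sum of elements 0..j-1
    acc += input[j];
  }
}

int main() {
  float row[5] = {1, 2, 3, 4, 5};
  float out[5];
  toy_exclusive_cumsum(row, out, 5);
  for (float v : out) std::printf("%g ", v);  // 0 1 3 6 10
  std::printf("\n");
}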
@@ -1,139 +0,0 @@
-// Copyright © 2023 Apple Inc.
-
-#include <cassert>
-
-#include <Accelerate/Accelerate.h>
-#include <simd/vector.h>
-
-#include "mlx/backend/common/reduce.h"
-#include "mlx/primitives.h"
-
-namespace mlx::core {
-
-namespace {
-
-template <typename T, typename VT>
-struct MinReduction {
-  T operator()(const T& a, const T& b) {
-    return std::min(a, b);
-  }
-
-  VT operator()(VT a, VT b) {
-    return simd_min(a, b);
-  }
-};
-
-template <typename T, typename VT>
-struct MaxReduction {
-  T operator()(const T& a, const T& b) {
-    return std::max(a, b);
-  }
-
-  VT operator()(VT a, VT b) {
-    return simd_max(a, b);
-  }
-};
-
-template <typename T, typename VT>
-struct SumReduction {
-  T operator()(const T& a, const T& b) {
-    return a + b;
-  }
-
-  VT operator()(VT a, VT b) {
-    return a + b;
-  }
-};
-
-template <typename T, typename VT, int N, typename Reduction>
-struct StridedReduce {
-  void operator()(const T* x, T* accum, int size, size_t stride) {
-    Reduction op;
-
-    for (int i = 0; i < size; i++) {
-      size_t s = stride;
-      T* a = accum;
-      while (s >= N) {
-        *(VT*)a = op((*(VT*)x), (*(VT*)a));
-        x += N;
-        a += N;
-        s -= N;
-      }
-      while (s-- > 0) {
-        *a = op(*a, *x);
-        a++;
-        x++;
-      }
-    }
-  }
-};
-
-} // namespace
-
-void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  auto& in = inputs[0];
-
-  if (in.dtype() == float32) {
-    if (reduce_type_ == Reduce::Sum) {
-      reduction_op<float, float>(
-          in,
-          out,
-          axes_,
-          0,
-          StridedReduce<
-              float,
-              simd_float16,
-              16,
-              SumReduction<float, simd_float16>>(),
-          [](const auto* x, auto* accum, int size) {
-            float acc;
-            vDSP_sve((const float*)x, 1, &acc, size);
-            (*accum) += acc;
-          },
-          [](auto* accum, auto x) { *accum += x; });
-      return;
-    } else if (reduce_type_ == Reduce::Max) {
-      reduction_op<float, float>(
-          in,
-          out,
-          axes_,
-          -std::numeric_limits<float>::infinity(),
-          StridedReduce<
-              float,
-              simd_float16,
-              16,
-              MaxReduction<float, simd_float16>>(),
-          [](const auto* x, auto* accum, int size) {
-            float max;
-            vDSP_maxv((const float*)x, 1, &max, size);
-            (*accum) = (*accum < max) ? max : *accum;
-          },
-          [](auto* accum, auto x) { (*accum) = (*accum < x) ? x : *accum; });
-      return;
-    } else if (reduce_type_ == Reduce::Min) {
-      reduction_op<float, float>(
-          in,
-          out,
-          axes_,
-          std::numeric_limits<float>::infinity(),
-          StridedReduce<
-              float,
-              simd_float16,
-              16,
-              MinReduction<float, simd_float16>>(),
-          [](const auto* x, auto* accum, int size) {
-            float min;
-            vDSP_minv((const float*)x, 1, &min, size);
-            (*accum) = (*accum > min) ? min : *accum;
-          },
-          [](auto* accum, auto x) { (*accum) = (*accum > x) ? x : *accum; });
-      return;
-    }
-  }
-  // TODO: Add integer addition and min/max using the templates above and
-  // simd_int16 and friends.
-  eval(inputs, out);
-}
-
-} // namespace mlx::core
@@ -1,22 +0,0 @@
-// Copyright © 2023-2024 Apple Inc.
-
-#include "mlx/array.h"
-#include "mlx/primitives.h"
-
-#define DEFAULT(primitive)                                                  \
-  void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
-    primitive::eval(inputs, out);                                           \
-  }
-
-#define DEFAULT_MULTI(primitive)                                        \
-  void primitive::eval_cpu(                                             \
-      const std::vector<array>& inputs, std::vector<array>& outputs) { \
-    primitive::eval(inputs, outputs);                                   \
-  }
-
-namespace mlx::core {
-
-DEFAULT(Reduce)
-DEFAULT(Scan)
-
-} // namespace mlx::core
@@ -5,6 +5,7 @@
 #include <limits>
 
 #include "mlx/backend/common/reduce.h"
+#include "mlx/backend/common/simd/simd.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {
@@ -67,55 +68,121 @@ const complex64_t Limits<complex64_t>::min =
 
 struct AndReduce {
   template <typename T>
-  void operator()(bool* a, T b) {
-    (*a) &= (b != 0);
+  bool operator()(bool x, T y) {
+    return x & (y != 0);
   }
 
-  void operator()(bool* y, bool x) {
-    (*y) &= x;
+  bool operator()(bool x, bool y) {
+    return x & y;
   }
+
+  template <int N, typename T>
+  simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
+    return x & (y != 0);
+  };
+
+  template <int N>
+  simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
+    return x & y;
+  };
+
+  template <int N, typename T>
+  bool operator()(simd::Simd<T, N> x) {
+    return simd::all(x);
+  };
 };
 
 struct OrReduce {
   template <typename T>
-  void operator()(bool* a, T b) {
-    (*a) |= (b != 0);
+  bool operator()(bool x, T y) {
+    return x | (y != 0);
   }
 
-  void operator()(bool* y, bool x) {
-    (*y) |= x;
+  bool operator()(bool x, bool y) {
+    return x | y;
   }
+
+  template <int N, typename T>
+  simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
+    return x | (y != 0);
+  };
+
+  template <int N>
+  simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
+    return x | y;
+  };
+
+  template <int N, typename T>
+  bool operator()(simd::Simd<T, N> x) {
+    return simd::any(x);
+  };
 };
 
 struct MaxReduce {
   template <typename T>
-  std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
-    (*y) = (*y > x) ? *y : x;
+  T operator()(T y, T x) {
+    return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
   };
 
-  template <typename T>
-  std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
-    if (std::isnan(x)) {
-      *y = x;
-    } else {
-      (*y) = (*y > x) ? *y : x;
-    }
+  template <int N, typename T>
+  simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
+    return simd::maximum(x, y);
+  };
+
+  template <int N, typename T>
+  T operator()(simd::Simd<T, N> x) {
+    return simd::max(x);
   };
 };
 
 struct MinReduce {
   template <typename T>
-  std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
-    (*y) = (*y < x) ? *y : x;
+  T operator()(T y, T x) {
+    return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
   };
 
-  template <typename T>
-  std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
-    if (std::isnan(x)) {
-      *y = x;
-    } else {
-      (*y) = (*y < x) ? *y : x;
-    }
+  template <int N, typename T>
+  simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
+    return simd::minimum(x, y);
+  };
+
+  template <int N, typename T>
+  T operator()(simd::Simd<T, N> x) {
+    return simd::min(x);
   };
 };
+
+struct SumReduce {
+  template <typename T, typename U>
+  U operator()(U y, T x) {
+    return x + y;
+  };
+
+  template <int N, typename T, typename U>
+  simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
+    return y + x;
+  };
+
+  template <int N, typename T>
+  T operator()(simd::Simd<T, N> x) {
+    return simd::sum(x);
+  };
+};
+
+struct ProdReduce {
+  template <typename T, typename U>
+  U operator()(U y, T x) {
+    return x * y;
+  };
+
+  template <int N, typename T, typename U>
+  simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
+    return x * y;
+  };
+
+  template <int N, typename T>
+  T operator()(simd::Simd<T, N> x) {
+    return simd::prod(x);
+  };
+};
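The rewritten operator structs above expose each reduction in three shapes: a scalar fold, a lane-wise fold over a SIMD pack, and a final cross-lane fold. A minimal standalone sketch of that calling pattern, using std::array as a stand-in for simd::Simd (ToyMaxReduce and Vec are illustrative names, not MLX types):

#include <algorithm>
#include <array>
#include <cstdio>

// Stand-in for simd::Simd<float, N>: a fixed-width pack of lanes.
template <int N>
using Vec = std::array<float, N>;

struct ToyMaxReduce {
  // Scalar accumulate: fold one value into the running maximum.
  float operator()(float acc, float x) const {
    return std::max(acc, x);
  }
  // Lane-wise accumulate: fold a pack of values into a pack of accumulators.
  template <int N>
  Vec<N> operator()(Vec<N> acc, Vec<N> x) const {
    for (int i = 0; i < N; ++i)
      acc[i] = std::max(acc[i], x[i]);
    return acc;
  }
  // Cross-lane fold: collapse the pack of accumulators to one scalar.
  template <int N>
  float operator()(Vec<N> acc) const {
    return *std::max_element(acc.begin(), acc.end());
  }
};

int main() {
  ToyMaxReduce op;
  Vec<4> acc{-1e30f, -1e30f, -1e30f, -1e30f};
  Vec<4> chunk{1.f, 7.f, 3.f, 5.f};
  acc = op(acc, chunk);         // vector step over a contiguous chunk
  float result = op(acc);       // horizontal reduction of the lanes
  result = op(result, 9.f);     // scalar tail element
  std::printf("%f\n", result);  // prints 9.000000
}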
@@ -139,18 +206,16 @@ void reduce_dispatch_sum_prod(
     Reduce::ReduceType rtype,
     const std::vector<int>& axes) {
   if (rtype == Reduce::Sum) {
-    auto op = [](auto y, auto x) { (*y) = (*y) + x; };
     if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
-      reduction_op<InT, int32_t>(in, out, axes, 0, op);
+      reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
     } else {
-      reduction_op<InT, InT>(in, out, axes, 0, op);
+      reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
     }
   } else {
-    auto op = [](auto y, auto x) { (*y) *= x; };
     if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
-      reduction_op<InT, int32_t>(in, out, axes, 1, op);
+      reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
     } else {
-      reduction_op<InT, InT>(in, out, axes, 1, op);
+      reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
     }
   }
 }
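Note that reduce_dispatch_sum_prod keeps accumulating small integer inputs into an int32_t accumulator (reduction_op<InT, int32_t> when sizeof(InT) <= 4), so the running sum or product does not wrap in the narrow input type. A small sketch of why the widened accumulator matters (toy_sum_widened is an illustrative name):

#include <cstdint>
#include <cstdio>

// Accumulate int8 inputs into an int32 accumulator, mirroring the
// reduction_op<InT, int32_t> dispatch above.
int32_t toy_sum_widened(const int8_t* x, int n) {
  int32_t acc = 0;  // init = 0, but 32 bits wide
  for (int i = 0; i < n; ++i)
    acc += x[i];    // no wrap-around even though int8 overflows past 127
  return acc;
}

int main() {
  int8_t data[100];
  for (int i = 0; i < 100; ++i)
    data[i] = 100;  // true sum is 10000, far outside int8 range
  std::printf("%d\n", toy_sum_widened(data, 100));  // 10000
}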
@@ -195,7 +260,7 @@ void nd_loop(
   loop_inner(0, 0);
 }
 
-void Reduce::eval(const std::vector<array>& inputs, array& out) {
+void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
   switch (reduce_type_) {
@@ -2,6 +2,7 @@
 
 #pragma once
 
+#include "mlx/backend/common/simd/simd.h"
 #include "mlx/backend/common/utils.h"
 
 namespace mlx::core {
@@ -60,45 +61,54 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
     const std::vector<int>& axes);
 
 template <typename T, typename U, typename Op>
-struct DefaultStridedReduce {
-  Op op;
-
-  DefaultStridedReduce(Op op_) : op(op_) {}
-
-  void operator()(const T* x, U* accumulator, int size, size_t stride) {
-    for (int i = 0; i < size; i++) {
-      U* moving_accumulator = accumulator;
-      for (int j = 0; j < stride; j++) {
-        op(moving_accumulator, *x);
-        moving_accumulator++;
-        x++;
-      }
-    }
-  }
-};
+void strided_reduce(
+    const T* x,
+    U* accumulator,
+    int size,
+    size_t stride,
+    Op op) {
+  constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
+  for (int i = 0; i < size; i++) {
+    U* moving_accumulator = accumulator;
+    auto s = stride;
+    while (s >= N) {
+      auto acc = simd::load<U, N>(moving_accumulator);
+      auto v = simd::Simd<U, N>(simd::load<T, N>(x));
+      simd::store<U, N>(moving_accumulator, op(acc, v));
+      moving_accumulator += N;
+      x += N;
+      s -= N;
+    }
+    while (s-- > 0) {
+      *moving_accumulator = op(*moving_accumulator, *x);
+      moving_accumulator++;
+      x++;
+    }
+  }
+};
 
 template <typename T, typename U, typename Op>
-struct DefaultContiguousReduce {
-  Op op;
-
-  DefaultContiguousReduce(Op op_) : op(op_) {}
-
-  void operator()(const T* x, U* accumulator, int size) {
-    while (size-- > 0) {
-      op(accumulator, *x);
-      x++;
-    }
-  }
-};
+void contiguous_reduce(const T* x, U* accumulator, int size, Op op, U init) {
+  constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
+  simd::Simd<U, N> accumulator_v(init);
+  while (size >= N) {
+    accumulator_v = op(accumulator_v, simd::Simd<U, N>(simd::load<T, N>(x)));
+    x += N;
+    size -= N;
+  }
+  *accumulator = op(*accumulator, op(accumulator_v));
+  while (size-- > 0) {
+    *accumulator = op(*accumulator, *x);
+    x++;
+  }
+}
 
-template <typename T, typename U, typename OpS, typename OpC, typename Op>
+template <typename T, typename U, typename Op>
 void reduction_op(
     const array& x,
     array& out,
     const std::vector<int>& axes,
     U init,
-    OpS ops,
-    OpC opc,
     Op op) {
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
   ReductionPlan plan = get_reduction_plan(x, axes);
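The new contiguous_reduce above follows a common vectorization pattern: keep an N-lane accumulator initialized with the reduction's identity, fold full N-wide chunks lane by lane, collapse the lanes once, then finish the remainder with scalar steps. A plain-C++ sketch of the same flow for a sum, with a fixed lane count instead of simd::max_size (toy_contiguous_sum and kLanes are illustrative names; init is assumed to be the operation's identity):

#include <cstdio>

constexpr int kLanes = 4;

float toy_contiguous_sum(const float* x, int size, float init) {
  // Lane-wise accumulator; filling every lane with init relies on init
  // being the identity of the operation (0 for a sum).
  float lanes[kLanes] = {init, init, init, init};
  while (size >= kLanes) {        // full chunks: lane-wise fold
    for (int i = 0; i < kLanes; ++i)
      lanes[i] += x[i];
    x += kLanes;
    size -= kLanes;
  }
  float acc = init;               // horizontal fold of the lanes
  for (int i = 0; i < kLanes; ++i)
    acc += lanes[i];
  while (size-- > 0)              // scalar tail
    acc += *x++;
  return acc;
}

int main() {
  float data[10] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  std::printf("%f\n", toy_contiguous_sum(data, 10, 0.0f));  // 55.000000
}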
@@ -106,7 +116,7 @@ void reduction_op(
   if (plan.type == ContiguousAllReduce) {
     U* out_ptr = out.data<U>();
     *out_ptr = init;
-    opc(x.data<T>(), out_ptr, x.size());
+    contiguous_reduce(x.data<T>(), out_ptr, x.size(), op, init);
     return;
   }
 
@@ -116,7 +126,7 @@ void reduction_op(
     U* out_ptr = out.data<U>();
     for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
       *out_ptr = init;
-      opc(x_ptr, out_ptr, reduction_size);
+      contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init);
     }
     return;
   }
@@ -134,7 +144,7 @@ void reduction_op(
     for (int i = 0; i < out.size(); i++, out_ptr++) {
       int offset = elem_to_loc(i, shape, strides);
       *out_ptr = init;
-      opc(x_ptr + offset, out_ptr, reduction_size);
+      contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init);
     }
   } else {
     for (int i = 0; i < out.size(); i++, out_ptr++) {
@@ -142,7 +152,12 @@ void reduction_op(
       *out_ptr = init;
       nd_loop(
           [&](int extra_offset) {
-            opc(x_ptr + offset + extra_offset, out_ptr, reduction_size);
+            contiguous_reduce(
+                x_ptr + offset + extra_offset,
+                out_ptr,
+                reduction_size,
+                op,
+                init);
           },
           plan.shape,
           plan.strides);
@@ -160,7 +175,7 @@ void reduction_op(
     U* out_ptr = out.data<U>();
     for (int i = 0; i < out.size(); i += reduction_stride) {
       std::fill_n(out_ptr, reduction_stride, init);
-      ops(x_ptr, out_ptr, reduction_size, reduction_stride);
+      strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op);
       x_ptr += reduction_stride * reduction_size;
       out_ptr += reduction_stride;
     }
@@ -180,7 +195,8 @@ void reduction_op(
     for (int i = 0; i < out.size(); i += reduction_stride) {
       int offset = elem_to_loc(i, shape, strides);
       std::fill_n(out_ptr, reduction_stride, init);
-      ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride);
+      strided_reduce(
+          x_ptr + offset, out_ptr, reduction_size, reduction_stride, op);
       out_ptr += reduction_stride;
     }
   } else {
@@ -189,10 +205,12 @@ void reduction_op(
       std::fill_n(out_ptr, reduction_stride, init);
       nd_loop(
           [&](int extra_offset) {
-            ops(x_ptr + offset + extra_offset,
+            strided_reduce(
+                x_ptr + offset + extra_offset,
                 out_ptr,
                 reduction_size,
-                reduction_stride);
+                reduction_stride,
+                op);
           },
           plan.shape,
           plan.strides);
@@ -210,7 +228,9 @@ void reduction_op(
       int offset = elem_to_loc(i, shape, strides);
       U val = init;
       nd_loop(
-          [&](int extra_offset) { op(&val, *(x_ptr + offset + extra_offset)); },
+          [&](int extra_offset) {
+            val = op(val, *(x_ptr + offset + extra_offset));
+          },
           plan.shape,
           plan.strides);
       *out_ptr = val;
@@ -218,16 +238,4 @@ void reduction_op(
     }
   }
 }
 
-template <typename T, typename U, typename Op>
-void reduction_op(
-    const array& x,
-    array& out,
-    const std::vector<int>& axes,
-    U init,
-    Op op) {
-  DefaultStridedReduce<T, U, Op> ops(op);
-  DefaultContiguousReduce<T, U, Op> opc(op);
-  reduction_op<T, U>(x, out, axes, init, ops, opc, op);
-}
-
 } // namespace mlx::core
@@ -3,6 +3,7 @@
 #include <cassert>
 
 #include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/simd/simd.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"
 
@@ -11,19 +12,15 @@ namespace mlx::core {
 namespace {
 
 template <typename T, typename U, typename Op>
-struct DefaultContiguousScan {
-  Op op;
-  U init;
-
-  DefaultContiguousScan(Op op_, U init_) : op(op_), init(init_) {}
-
-  void operator()(
-      const T* input,
-      U* output,
-      int count,
-      int stride,
-      bool reverse,
-      bool inclusive) {
+void contiguous_scan(
+    const T* input,
+    U* output,
+    int count,
+    int stride,
+    bool reverse,
+    bool inclusive,
+    const Op& op,
+    U init) {
   if (!reverse) {
     if (inclusive) {
       for (int i = 0; i < count; i++) {
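After this refactor the scan kernels take the combining op and its identity as ordinary parameters, and the op returns a value instead of writing through pointers. A minimal sketch of a 1-D scan with that interface (toy_scan is an illustrative name, not the MLX contiguous_scan, which also handles strides and reverse):

#include <cstdio>

// Inclusive or exclusive prefix scan over one contiguous row, with the
// combining op passed in as a value-returning callable and init as its
// identity element.
template <typename T, typename U, typename Op>
void toy_scan(const T* in, U* out, int n, bool inclusive, Op op, U init) {
  U acc = init;
  for (int i = 0; i < n; ++i) {
    if (inclusive) {
      acc = op(acc, in[i]);  // fold first, then write
      out[i] = acc;
    } else {
      out[i] = acc;          // write the running value, then fold
      acc = op(acc, in[i]);
    }
  }
}

int main() {
  int x[5] = {1, 2, 3, 4, 5};
  int y[5];
  toy_scan(x, y, 5, true, [](int a, int b) { return a + b; }, 0);
  for (int v : y) std::printf("%d ", v);  // 1 3 6 10 15
  std::printf("\n");
}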
@@ -31,7 +28,7 @@ struct DefaultContiguousScan {
       for (int j = 1; j < stride; j++) {
         input++;
         output++;
-        op(output, output - 1, input);
+        *output = op(*(output - 1), *input);
       }
       output++;
       input++;

@@ -40,7 +37,7 @@ struct DefaultContiguousScan {
     for (int i = 0; i < count; i++) {
       *output = init;
       for (int j = 1; j < stride; j++) {
-        op(output + 1, output, input);
+        *(output + 1) = op(*output, *input);
         input++;
         output++;
       }

@@ -57,7 +54,7 @@ struct DefaultContiguousScan {
       for (int j = 1; j < stride; j++) {
         input--;
         output--;
-        op(output, output + 1, input);
+        *output = op(*(output + 1), *input);
       }
       output += stride;
       input += stride;

@@ -68,7 +65,7 @@ struct DefaultContiguousScan {
       input += stride - 1;
       *output = init;
       for (int j = 1; j < stride; j++) {
-        op(output - 1, output, input);
+        *(output - 1) = op(*output, *input);
         input--;
         output--;
       }
@@ -77,24 +74,19 @@ struct DefaultContiguousScan {
       }
     }
   }
-  }
 };
 
 template <typename T, typename U, typename Op>
-struct DefaultStridedScan {
-  Op op;
-  U init;
-
-  DefaultStridedScan(Op op_, U init_) : op(op_), init(init_) {}
-
-  void operator()(
-      const T* input,
-      U* output,
-      int count,
-      int size,
-      int stride,
-      bool reverse,
-      bool inclusive) {
+void strided_scan(
+    const T* input,
+    U* output,
+    int count,
+    int size,
+    int stride,
+    bool reverse,
+    bool inclusive,
+    const Op& op,
+    U init) {
   // TODO: Vectorize the following naive implementation
   if (!reverse) {
     if (inclusive) {
@@ -104,7 +96,7 @@ struct DefaultStridedScan {
       input += stride;
       for (int j = 1; j < size; j++) {
         for (int k = 0; k < stride; k++) {
-          op(output, output - stride, input);
+          *output = op(*(output - stride), *input);
           output++;
           input++;
         }

@@ -117,7 +109,7 @@ struct DefaultStridedScan {
       input += stride;
       for (int j = 1; j < size; j++) {
         for (int k = 0; k < stride; k++) {
-          op(output, output - stride, input - stride);
+          *output = op(*(output - stride), *(input - stride));
           output++;
           input++;
         }

@@ -134,7 +126,7 @@ struct DefaultStridedScan {
       for (int k = 0; k < stride; k++) {
         output--;
         input--;
-        op(output, output + stride, input);
+        *output = op(*(output + stride), *input);
       }
     }
     output += size * stride;

@@ -149,7 +141,7 @@ struct DefaultStridedScan {
      for (int k = 0; k < stride; k++) {
        output--;
        input--;
-        op(output, output + stride, input + stride);
+        *output = op(*(output + stride), *(input + stride));
      }
    }
    output += size * stride;
@@ -157,38 +149,41 @@ struct DefaultStridedScan {
       }
     }
   }
-  }
 };
 
-template <typename T, typename U, typename OpCS, typename OpSS>
+template <typename T, typename U, typename Op>
 void scan_op(
-    OpCS opcs,
-    OpSS opss,
     const array& input,
     array& output,
     int axis,
     bool reverse,
-    bool inclusive) {
+    bool inclusive,
+    const Op& op,
+    U init) {
   output.set_data(allocator::malloc_or_wait(output.nbytes()));
 
   if (input.flags().row_contiguous) {
     if (input.strides()[axis] == 1) {
-      opcs(
+      contiguous_scan(
           input.data<T>(),
           output.data<U>(),
           input.size() / input.shape(axis),
           input.shape(axis),
           reverse,
-          inclusive);
+          inclusive,
+          op,
+          init);
     } else {
-      opss(
+      strided_scan(
           input.data<T>(),
           output.data<U>(),
           input.size() / input.shape(axis) / input.strides()[axis],
           input.shape(axis),
           input.strides()[axis],
           reverse,
-          inclusive);
+          inclusive,
+          op,
+          init);
     }
   } else {
     throw std::runtime_error("Scan op supports only contiguous inputs");
@@ -205,39 +200,31 @@ void scan_dispatch(
     bool inclusive) {
   switch (rtype) {
     case Scan::Sum: {
-      auto op = [](U* o, const U* y, const T* x) { *o = *y + *x; };
+      auto op = [](U y, T x) { return y + x; };
       auto init = static_cast<U>(0);
-      auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
-      auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
-      scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
+      scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
       break;
     }
     case Scan::Prod: {
-      auto op = [](U* o, const U* y, const T* x) { *o = *y * (*x); };
+      auto op = [](U y, T x) { return y * x; };
       auto init = static_cast<U>(1);
-      auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
-      auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
-      scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
+      scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
       break;
     }
     case Scan::Min: {
-      auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
+      auto op = [](U y, T x) { return x < y ? x : y; };
       auto init = (issubdtype(input.dtype(), floating))
           ? static_cast<U>(std::numeric_limits<float>::infinity())
           : std::numeric_limits<U>::max();
-      auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
-      auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
-      scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
+      scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
       break;
     }
     case Scan::Max: {
-      auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
+      auto op = [](U y, T x) { return x < y ? y : x; };
       auto init = (issubdtype(input.dtype(), floating))
          ? static_cast<U>(-std::numeric_limits<float>::infinity())
          : std::numeric_limits<U>::min();
-      auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
-      auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
-      scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
+      scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
       break;
     }
   }
@@ -245,7 +232,7 @@ void scan_dispatch(
 
 } // namespace
 
-void Scan::eval(const std::vector<array>& inputs, array& out) {
+void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
 
   // Ensure contiguity
@@ -267,6 +267,10 @@ Simd<T, N> fma(Simd<T, N> x, Simd<T, N> y, U z) {
 
 // Reductions
 
+template <typename T, int N>
+bool all(Simd<T, N> x) {
+  return asd::all(x.value);
+}
 template <typename T, int N>
 bool any(Simd<T, N> x) {
   return asd::any(x.value);

@@ -284,6 +288,14 @@ T min(Simd<T, N> x) {
   return asd::reduce_min(x.value);
 }
 
+template <typename T, int N>
+T prod(Simd<T, N> x) {
+  auto ptr = (T*)&x;
+  auto lhs = load<T, N / 2>(ptr);
+  auto rhs = load<T, N / 2>(ptr + N / 2);
+  return prod(lhs * rhs);
+}
+
 } // namespace mlx::core::simd
 
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
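The prod reduction added above folds a power-of-two-wide vector by multiplying its low and high halves and recursing on the half-width result; in the real code the recursion bottoms out at the width-1 Simd specialization. A plain-array sketch of that halve-and-multiply idea (toy_prod is an illustrative name):

#include <cstdio>

// Horizontal product of N floats (N a power of two) by repeatedly folding
// the high half into the low half.
template <int N>
float toy_prod(const float* x) {
  if constexpr (N == 1) {
    return x[0];  // width-1 base case
  } else {
    float folded[N / 2];
    for (int i = 0; i < N / 2; ++i)
      folded[i] = x[i] * x[i + N / 2];  // multiply the two halves lane-wise
    return toy_prod<N / 2>(folded);     // recurse on the folded half
  }
}

int main() {
  float v[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  std::printf("%f\n", toy_prod<8>(v));  // 40320.000000 (= 8!)
}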
@@ -246,6 +246,7 @@ Simd<T, 1> fma(Simd<T, 1> x, Simd<T, 1> y, U z) {
 DEFAULT_REDUCTION(max, T)
 DEFAULT_REDUCTION(min, T)
 DEFAULT_REDUCTION(sum, T)
+DEFAULT_REDUCTION(prod, T)
 DEFAULT_REDUCTION(any, bool)
 DEFAULT_REDUCTION(all, bool)
 
@@ -200,5 +200,13 @@ inline float16_t sum(Simd<float16_t, N> x) {
   y = vpadd_f16(y, y);
   return vget_lane_f16(y, 0);
 }
+inline float16_t prod(Simd<float16_t, N> x) {
+  auto hx = vmul_f16(vget_low_f16(x.value), vget_high_f16(x.value));
+  auto out = hx[0];
+  hx[0] *= hx[1];
+  hx[0] *= hx[2];
+  hx[0] *= hx[3];
+  return hx[0];
+}
 
 } // namespace mlx::core::simd
@@ -1691,8 +1691,6 @@ class Reduce : public UnaryPrimitive {
  private:
   ReduceType reduce_type_;
   std::vector<int> axes_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Round : public UnaryPrimitive {

@@ -1758,8 +1756,6 @@ class Scan : public UnaryPrimitive {
   int axis_;
   bool reverse_;
   bool inclusive_;
-
-  void eval(const std::vector<array>& inputs, array& out);
 };
 
 class Scatter : public UnaryPrimitive {