mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 13:01:17 +08:00
parent
31e134be35
commit
fe3167d7ea
@ -46,6 +46,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
|
@ -196,6 +196,20 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalAnd());
|
||||
}
|
||||
|
||||
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalOr());
|
||||
}
|
||||
|
||||
void Maximum::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
|
@ -8,7 +8,6 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/arange.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
@ -314,20 +313,6 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
|
||||
unary(in, out, detail::LogicalNot());
|
||||
}
|
||||
|
||||
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalAnd());
|
||||
}
|
||||
|
||||
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalOr());
|
||||
}
|
||||
|
||||
void Negative::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
@ -104,48 +104,14 @@ void reduce_dispatch_out(
|
||||
}
|
||||
case Reduce::Sum: {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
reduction_op<InT, bool>(in, out, axes, false, op);
|
||||
break;
|
||||
case uint8:
|
||||
reduction_op<InT, uint8_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case uint16:
|
||||
reduction_op<InT, uint16_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case uint32:
|
||||
reduction_op<InT, uint32_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case uint64:
|
||||
reduction_op<InT, uint64_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int8:
|
||||
reduction_op<InT, int8_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int16:
|
||||
reduction_op<InT, int16_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int32:
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int64:
|
||||
reduction_op<InT, int64_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case float16:
|
||||
reduction_op<InT, float16_t>(in, out, axes, 0.0f, op);
|
||||
break;
|
||||
case float32:
|
||||
reduction_op<InT, float>(in, out, axes, 0.0f, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduction_op<InT, bfloat16_t>(in, out, axes, 0.0f, op);
|
||||
break;
|
||||
case complex64:
|
||||
reduction_op<InT, complex64_t>(in, out, axes, complex64_t{0.0f}, op);
|
||||
break;
|
||||
if (out.dtype() == int32) {
|
||||
// special case since the input type can be bool
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 0, op);
|
||||
}
|
||||
} break;
|
||||
break;
|
||||
}
|
||||
case Reduce::Prod: {
|
||||
auto op = [](auto y, auto x) { (*y) *= x; };
|
||||
reduction_op<InT, InT>(in, out, axes, 1, op);
|
||||
@ -168,6 +134,29 @@ void reduce_dispatch_out(
|
||||
|
||||
} // namespace
|
||||
|
||||
void nd_loop(
|
||||
std::function<void(int)> callback,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides) {
|
||||
std::function<void(int, int)> loop_inner;
|
||||
loop_inner = [&](int dim, int offset) {
|
||||
if (dim < shape.size() - 1) {
|
||||
int size = shape[dim];
|
||||
size_t stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
loop_inner(dim + 1, offset + i * stride);
|
||||
}
|
||||
} else {
|
||||
int size = shape[dim];
|
||||
size_t stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
callback(offset + i * stride);
|
||||
}
|
||||
}
|
||||
};
|
||||
loop_inner(0, 0);
|
||||
}
|
||||
|
||||
void Reduce::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
@ -49,47 +49,18 @@ struct ReductionPlan {
|
||||
ReductionPlan(ReductionOpType type_) : type(type_) {}
|
||||
};
|
||||
|
||||
namespace {
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes);
|
||||
|
||||
// Helper for the ndimensional strided loop
|
||||
// Should this be in utils?
|
||||
inline void nd_loop(
|
||||
void nd_loop(
|
||||
std::function<void(int)> callback,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides) {
|
||||
std::function<void(int, int)> loop_inner;
|
||||
loop_inner = [&](int dim, int offset) {
|
||||
if (dim < shape.size() - 1) {
|
||||
int size = shape[dim];
|
||||
size_t stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
loop_inner(dim + 1, offset + i * stride);
|
||||
}
|
||||
} else {
|
||||
int size = shape[dim];
|
||||
size_t stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
callback(offset + i * stride);
|
||||
}
|
||||
}
|
||||
};
|
||||
loop_inner(0, 0);
|
||||
}
|
||||
const std::vector<size_t>& strides);
|
||||
|
||||
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes) {
|
||||
std::vector<int> shape = x.shape();
|
||||
std::vector<size_t> strides = x.strides();
|
||||
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int a = axes[i];
|
||||
shape.erase(shape.begin() + a);
|
||||
strides.erase(strides.begin() + a);
|
||||
}
|
||||
|
||||
return std::make_pair(shape, strides);
|
||||
}
|
||||
const std::vector<int>& axes);
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultStridedReduce {
|
||||
@ -123,102 +94,6 @@ struct DefaultContiguousReduce {
|
||||
}
|
||||
};
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
x.flags().contiguous) {
|
||||
return ContiguousAllReduce;
|
||||
}
|
||||
|
||||
// Row contiguous input so the output is row contiguous
|
||||
if (x.flags().row_contiguous) {
|
||||
// Merge consecutive axes
|
||||
std::vector<int> shape = {x.shape(axes[0])};
|
||||
std::vector<size_t> strides = {x.strides()[axes[0]]};
|
||||
for (int i = 1; i < axes.size(); i++) {
|
||||
if (axes[i] - 1 == axes[i - 1]) {
|
||||
shape.back() *= x.shape(axes[i]);
|
||||
strides.back() = x.strides()[axes[i]];
|
||||
} else {
|
||||
shape.push_back(x.shape(axes[i]));
|
||||
strides.push_back(x.strides()[axes[i]]);
|
||||
}
|
||||
}
|
||||
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(ContiguousReduce, shape, strides);
|
||||
} else if (strides.back() > 1) {
|
||||
return ReductionPlan(ContiguousStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
// Let's check if we can optimize our access patterns
|
||||
//
|
||||
// 1. We have a reduction axis with stride 1. Simply call
|
||||
// GeneralContiguousReduce and be done with it.
|
||||
// 2. We have transpositions and we are not reducing over the axis with
|
||||
// stride 1. However, we are reducing over an axis where everything is
|
||||
// contiguous in memory to the right of that axis. We can call strided
|
||||
// reduce and be done with it.
|
||||
// 2. We have weird transpositions and expands. Copy the strides to the
|
||||
// output, then call strided reduce.
|
||||
|
||||
// Sort reduction axes by stride in order to merge them and figure out if we
|
||||
// have a contiguous reduction.
|
||||
std::vector<std::pair<int, size_t>> reductions;
|
||||
for (auto a : axes) {
|
||||
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
|
||||
}
|
||||
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
|
||||
return a.second > b.second;
|
||||
});
|
||||
// Extract the two smallest and try to merge them in case the contiguous
|
||||
// reduction can be bigger than just the last axis.
|
||||
for (int i = reductions.size() - 1; i >= 1; i--) {
|
||||
auto a = reductions[i];
|
||||
auto b = reductions[i - 1];
|
||||
|
||||
// b.stride = a.shape * a.stride then a and b are contiguous
|
||||
if (b.second == a.first * a.second) {
|
||||
reductions.erase(reductions.begin() + i);
|
||||
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
for (auto r : reductions) {
|
||||
shape.push_back(r.first);
|
||||
strides.push_back(r.second);
|
||||
}
|
||||
|
||||
// We can call the contiguous reduction op for every weird way the input is
|
||||
// structured in the rest of the axes.
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(GeneralContiguousReduce, shape, strides);
|
||||
}
|
||||
|
||||
// Delegate to the general strided reduction op if the axes after
|
||||
// strides.back() are contiguous.
|
||||
if (strides.back() > 1) {
|
||||
int size = 1;
|
||||
for (int i = x.ndim() - 1; i >= 0; i--) {
|
||||
if (axes.back() == i) {
|
||||
continue;
|
||||
}
|
||||
if (x.strides()[i] != size) {
|
||||
break;
|
||||
}
|
||||
size *= x.shape(i);
|
||||
}
|
||||
if (size >= strides.back()) {
|
||||
return ReductionPlan(GeneralStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
return ReductionPlan(GeneralReduce, shape, strides);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename OpS, typename OpC, typename Op>
|
||||
void reduction_op(
|
||||
const array& x,
|
||||
@ -361,6 +236,4 @@ void reduction_op(
|
||||
reduction_op<T, U>(x, out, axes, init, ops, opc, op);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
||||
|
118
mlx/backend/common/reduce_utils.cpp
Normal file
118
mlx/backend/common/reduce_utils.cpp
Normal file
@ -0,0 +1,118 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes) {
|
||||
std::vector<int> shape = x.shape();
|
||||
std::vector<size_t> strides = x.strides();
|
||||
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int a = axes[i];
|
||||
shape.erase(shape.begin() + a);
|
||||
strides.erase(strides.begin() + a);
|
||||
}
|
||||
|
||||
return std::make_pair(shape, strides);
|
||||
}
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
x.flags().contiguous) {
|
||||
return ContiguousAllReduce;
|
||||
}
|
||||
|
||||
// Row contiguous input so the output is row contiguous
|
||||
if (x.flags().row_contiguous) {
|
||||
// Merge consecutive axes
|
||||
std::vector<int> shape = {x.shape(axes[0])};
|
||||
std::vector<size_t> strides = {x.strides()[axes[0]]};
|
||||
for (int i = 1; i < axes.size(); i++) {
|
||||
if (axes[i] - 1 == axes[i - 1]) {
|
||||
shape.back() *= x.shape(axes[i]);
|
||||
strides.back() = x.strides()[axes[i]];
|
||||
} else {
|
||||
shape.push_back(x.shape(axes[i]));
|
||||
strides.push_back(x.strides()[axes[i]]);
|
||||
}
|
||||
}
|
||||
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(ContiguousReduce, shape, strides);
|
||||
} else if (strides.back() > 1) {
|
||||
return ReductionPlan(ContiguousStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
// Let's check if we can optimize our access patterns
|
||||
//
|
||||
// 1. We have a reduction axis with stride 1. Simply call
|
||||
// GeneralContiguousReduce and be done with it.
|
||||
// 2. We have transpositions and we are not reducing over the axis with
|
||||
// stride 1. However, we are reducing over an axis where everything is
|
||||
// contiguous in memory to the right of that axis. We can call strided
|
||||
// reduce and be done with it.
|
||||
// 2. We have weird transpositions and expands. Copy the strides to the
|
||||
// output, then call strided reduce.
|
||||
|
||||
// Sort reduction axes by stride in order to merge them and figure out if we
|
||||
// have a contiguous reduction.
|
||||
std::vector<std::pair<int, size_t>> reductions;
|
||||
for (auto a : axes) {
|
||||
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
|
||||
}
|
||||
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
|
||||
return a.second > b.second;
|
||||
});
|
||||
// Extract the two smallest and try to merge them in case the contiguous
|
||||
// reduction can be bigger than just the last axis.
|
||||
for (int i = reductions.size() - 1; i >= 1; i--) {
|
||||
auto a = reductions[i];
|
||||
auto b = reductions[i - 1];
|
||||
|
||||
// b.stride = a.shape * a.stride then a and b are contiguous
|
||||
if (b.second == a.first * a.second) {
|
||||
reductions.erase(reductions.begin() + i);
|
||||
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
for (auto r : reductions) {
|
||||
shape.push_back(r.first);
|
||||
strides.push_back(r.second);
|
||||
}
|
||||
|
||||
// We can call the contiguous reduction op for every weird way the input is
|
||||
// structured in the rest of the axes.
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(GeneralContiguousReduce, shape, strides);
|
||||
}
|
||||
|
||||
// Delegate to the general strided reduction op if the axes after
|
||||
// strides.back() are contiguous.
|
||||
if (strides.back() > 1) {
|
||||
int size = 1;
|
||||
for (int i = x.ndim() - 1; i >= 0; i--) {
|
||||
if (axes.back() == i) {
|
||||
continue;
|
||||
}
|
||||
if (x.strides()[i] != size) {
|
||||
break;
|
||||
}
|
||||
size *= x.shape(i);
|
||||
}
|
||||
if (size >= strides.back()) {
|
||||
return ReductionPlan(GeneralStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
return ReductionPlan(GeneralReduce, shape, strides);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -6,5 +6,6 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user