From fe3167d7ea0e9e2634683a12aba0a8087cc6cb57 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 14 Jun 2024 09:46:55 -0700 Subject: [PATCH] smaller CPU binary (#1203) * smaller CPU binary * fix no cpu build --- mlx/backend/common/CMakeLists.txt | 1 + mlx/backend/common/binary.cpp | 14 +++ mlx/backend/common/primitives.cpp | 15 ---- mlx/backend/common/reduce.cpp | 71 +++++++-------- mlx/backend/common/reduce.h | 135 +--------------------------- mlx/backend/common/reduce_utils.cpp | 118 ++++++++++++++++++++++++ mlx/backend/no_cpu/CMakeLists.txt | 1 + 7 files changed, 168 insertions(+), 187 deletions(-) create mode 100644 mlx/backend/common/reduce_utils.cpp diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 5a9fcc5ba..b5650b395 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -46,6 +46,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp diff --git a/mlx/backend/common/binary.cpp b/mlx/backend/common/binary.cpp index eba96fc5d..517c7e8a0 100644 --- a/mlx/backend/common/binary.cpp +++ b/mlx/backend/common/binary.cpp @@ -196,6 +196,20 @@ void LogAddExp::eval(const std::vector& inputs, array& out) { } } +void LogicalAnd::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); // LogicalAnd requires two input arrays + auto& in1 = inputs[0]; + auto& in2 = inputs[1]; + binary(in1, in2, out, detail::LogicalAnd()); +} + +void LogicalOr::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); // LogicalOr requires two input arrays + auto& in1 = inputs[0]; + auto& in2 = inputs[1]; + binary(in1, in2, out, detail::LogicalOr()); +} + void Maximum::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index d6b667c8d..a01717e73 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -8,7 +8,6 @@ #include "mlx/allocator.h" #include "mlx/backend/common/arange.h" -#include "mlx/backend/common/binary.h" #include "mlx/backend/common/copy.h" #include "mlx/backend/common/ops.h" #include "mlx/backend/common/slicing.h" @@ -314,20 +313,6 @@ void LogicalNot::eval(const std::vector& inputs, array& out) { unary(in, out, detail::LogicalNot()); } -void LogicalAnd::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); // LogicalAnd requires two input arrays - auto& in1 = inputs[0]; - auto& in2 = inputs[1]; - binary(in1, in2, out, detail::LogicalAnd()); -} - -void LogicalOr::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); // LogicalOr requires two input arrays - auto& in1 = inputs[0]; - auto& in2 = inputs[1]; - binary(in1, in2, out, detail::LogicalOr()); -} - void Negative::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; diff --git a/mlx/backend/common/reduce.cpp b/mlx/backend/common/reduce.cpp index e9d091e49..070262c68 100644 --- a/mlx/backend/common/reduce.cpp +++ b/mlx/backend/common/reduce.cpp @@ -104,48 +104,14 @@ void reduce_dispatch_out( } case Reduce::Sum: { auto op = [](auto y, auto x) { (*y) = (*y) + x; }; - switch (out.dtype()) { - case bool_: - reduction_op(in, out, axes, false, op); - break; - case uint8: - reduction_op(in, out, axes, 0, op); - break; - case uint16: - reduction_op(in, out, axes, 0, op); - break; - case uint32: - reduction_op(in, out, axes, 0, op); - break; - case uint64: - reduction_op(in, out, axes, 0, op); - break; - case int8: - reduction_op(in, out, axes, 0, op); - break; - case int16: - reduction_op(in, out, axes, 0, op); - break; - case int32: - reduction_op(in, out, axes, 0, op); - break; - case int64: - reduction_op(in, out, axes, 0, op); - break; - case float16: - reduction_op(in, out, axes, 0.0f, op); - break; - case float32: - reduction_op(in, out, axes, 0.0f, op); - break; - case bfloat16: - reduction_op(in, out, axes, 0.0f, op); - break; - case complex64: - reduction_op(in, out, axes, complex64_t{0.0f}, op); - break; + if (out.dtype() == int32) { + // special case since the input type can be bool + reduction_op(in, out, axes, 0, op); + } else { + reduction_op(in, out, axes, 0, op); } - } break; + break; + } case Reduce::Prod: { auto op = [](auto y, auto x) { (*y) *= x; }; reduction_op(in, out, axes, 1, op); @@ -168,6 +134,29 @@ void reduce_dispatch_out( } // namespace +void nd_loop( + std::function callback, + const std::vector& shape, + const std::vector& strides) { + std::function loop_inner; + loop_inner = [&](int dim, int offset) { + if (dim < shape.size() - 1) { + int size = shape[dim]; + size_t stride = strides[dim]; + for (int i = 0; i < size; i++) { + loop_inner(dim + 1, offset + i * stride); + } + } else { + int size = shape[dim]; + size_t stride = strides[dim]; + for (int i = 0; i < size; i++) { + callback(offset + i * stride); + } + } + }; + loop_inner(0, 0); +} + void Reduce::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; diff --git a/mlx/backend/common/reduce.h b/mlx/backend/common/reduce.h index 19bf652a8..0e9088579 100644 --- a/mlx/backend/common/reduce.h +++ b/mlx/backend/common/reduce.h @@ -49,47 +49,18 @@ struct ReductionPlan { ReductionPlan(ReductionOpType type_) : type(type_) {} }; -namespace { +ReductionPlan get_reduction_plan(const array& x, const std::vector axes); // Helper for the ndimensional strided loop // Should this be in utils? -inline void nd_loop( +void nd_loop( std::function callback, const std::vector& shape, - const std::vector& strides) { - std::function loop_inner; - loop_inner = [&](int dim, int offset) { - if (dim < shape.size() - 1) { - int size = shape[dim]; - size_t stride = strides[dim]; - for (int i = 0; i < size; i++) { - loop_inner(dim + 1, offset + i * stride); - } - } else { - int size = shape[dim]; - size_t stride = strides[dim]; - for (int i = 0; i < size; i++) { - callback(offset + i * stride); - } - } - }; - loop_inner(0, 0); -} + const std::vector& strides); std::pair, std::vector> shapes_without_reduction_axes( const array& x, - const std::vector& axes) { - std::vector shape = x.shape(); - std::vector strides = x.strides(); - - for (int i = axes.size() - 1; i >= 0; i--) { - int a = axes[i]; - shape.erase(shape.begin() + a); - strides.erase(strides.begin() + a); - } - - return std::make_pair(shape, strides); -} + const std::vector& axes); template struct DefaultStridedReduce { @@ -123,102 +94,6 @@ struct DefaultContiguousReduce { } }; -ReductionPlan get_reduction_plan(const array& x, const std::vector axes) { - // The data is all there and we are reducing over everything - if (x.size() == x.data_size() && axes.size() == x.ndim() && - x.flags().contiguous) { - return ContiguousAllReduce; - } - - // Row contiguous input so the output is row contiguous - if (x.flags().row_contiguous) { - // Merge consecutive axes - std::vector shape = {x.shape(axes[0])}; - std::vector strides = {x.strides()[axes[0]]}; - for (int i = 1; i < axes.size(); i++) { - if (axes[i] - 1 == axes[i - 1]) { - shape.back() *= x.shape(axes[i]); - strides.back() = x.strides()[axes[i]]; - } else { - shape.push_back(x.shape(axes[i])); - strides.push_back(x.strides()[axes[i]]); - } - } - - if (strides.back() == 1) { - return ReductionPlan(ContiguousReduce, shape, strides); - } else if (strides.back() > 1) { - return ReductionPlan(ContiguousStridedReduce, shape, strides); - } - } - - // Let's check if we can optimize our access patterns - // - // 1. We have a reduction axis with stride 1. Simply call - // GeneralContiguousReduce and be done with it. - // 2. We have transpositions and we are not reducing over the axis with - // stride 1. However, we are reducing over an axis where everything is - // contiguous in memory to the right of that axis. We can call strided - // reduce and be done with it. - // 2. We have weird transpositions and expands. Copy the strides to the - // output, then call strided reduce. - - // Sort reduction axes by stride in order to merge them and figure out if we - // have a contiguous reduction. - std::vector> reductions; - for (auto a : axes) { - reductions.push_back(std::make_pair(x.shape(a), x.strides()[a])); - } - std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) { - return a.second > b.second; - }); - // Extract the two smallest and try to merge them in case the contiguous - // reduction can be bigger than just the last axis. - for (int i = reductions.size() - 1; i >= 1; i--) { - auto a = reductions[i]; - auto b = reductions[i - 1]; - - // b.stride = a.shape * a.stride then a and b are contiguous - if (b.second == a.first * a.second) { - reductions.erase(reductions.begin() + i); - reductions[i - 1] = std::make_pair(a.first * b.first, a.second); - } - } - - std::vector shape; - std::vector strides; - for (auto r : reductions) { - shape.push_back(r.first); - strides.push_back(r.second); - } - - // We can call the contiguous reduction op for every weird way the input is - // structured in the rest of the axes. - if (strides.back() == 1) { - return ReductionPlan(GeneralContiguousReduce, shape, strides); - } - - // Delegate to the general strided reduction op if the axes after - // strides.back() are contiguous. - if (strides.back() > 1) { - int size = 1; - for (int i = x.ndim() - 1; i >= 0; i--) { - if (axes.back() == i) { - continue; - } - if (x.strides()[i] != size) { - break; - } - size *= x.shape(i); - } - if (size >= strides.back()) { - return ReductionPlan(GeneralStridedReduce, shape, strides); - } - } - - return ReductionPlan(GeneralReduce, shape, strides); -} - template void reduction_op( const array& x, @@ -361,6 +236,4 @@ void reduction_op( reduction_op(x, out, axes, init, ops, opc, op); } -} // namespace - } // namespace mlx::core diff --git a/mlx/backend/common/reduce_utils.cpp b/mlx/backend/common/reduce_utils.cpp new file mode 100644 index 000000000..47b0f6c32 --- /dev/null +++ b/mlx/backend/common/reduce_utils.cpp @@ -0,0 +1,118 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/common/reduce.h" + +namespace mlx::core { + +std::pair, std::vector> shapes_without_reduction_axes( + const array& x, + const std::vector& axes) { + std::vector shape = x.shape(); + std::vector strides = x.strides(); + + for (int i = axes.size() - 1; i >= 0; i--) { + int a = axes[i]; + shape.erase(shape.begin() + a); + strides.erase(strides.begin() + a); + } + + return std::make_pair(shape, strides); +} + +ReductionPlan get_reduction_plan(const array& x, const std::vector axes) { + // The data is all there and we are reducing over everything + if (x.size() == x.data_size() && axes.size() == x.ndim() && + x.flags().contiguous) { + return ContiguousAllReduce; + } + + // Row contiguous input so the output is row contiguous + if (x.flags().row_contiguous) { + // Merge consecutive axes + std::vector shape = {x.shape(axes[0])}; + std::vector strides = {x.strides()[axes[0]]}; + for (int i = 1; i < axes.size(); i++) { + if (axes[i] - 1 == axes[i - 1]) { + shape.back() *= x.shape(axes[i]); + strides.back() = x.strides()[axes[i]]; + } else { + shape.push_back(x.shape(axes[i])); + strides.push_back(x.strides()[axes[i]]); + } + } + + if (strides.back() == 1) { + return ReductionPlan(ContiguousReduce, shape, strides); + } else if (strides.back() > 1) { + return ReductionPlan(ContiguousStridedReduce, shape, strides); + } + } + + // Let's check if we can optimize our access patterns + // + // 1. We have a reduction axis with stride 1. Simply call + // GeneralContiguousReduce and be done with it. + // 2. We have transpositions and we are not reducing over the axis with + // stride 1. However, we are reducing over an axis where everything is + // contiguous in memory to the right of that axis. We can call strided + // reduce and be done with it. + // 2. We have weird transpositions and expands. Copy the strides to the + // output, then call strided reduce. + + // Sort reduction axes by stride in order to merge them and figure out if we + // have a contiguous reduction. + std::vector> reductions; + for (auto a : axes) { + reductions.push_back(std::make_pair(x.shape(a), x.strides()[a])); + } + std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) { + return a.second > b.second; + }); + // Extract the two smallest and try to merge them in case the contiguous + // reduction can be bigger than just the last axis. + for (int i = reductions.size() - 1; i >= 1; i--) { + auto a = reductions[i]; + auto b = reductions[i - 1]; + + // b.stride = a.shape * a.stride then a and b are contiguous + if (b.second == a.first * a.second) { + reductions.erase(reductions.begin() + i); + reductions[i - 1] = std::make_pair(a.first * b.first, a.second); + } + } + + std::vector shape; + std::vector strides; + for (auto r : reductions) { + shape.push_back(r.first); + strides.push_back(r.second); + } + + // We can call the contiguous reduction op for every weird way the input is + // structured in the rest of the axes. + if (strides.back() == 1) { + return ReductionPlan(GeneralContiguousReduce, shape, strides); + } + + // Delegate to the general strided reduction op if the axes after + // strides.back() are contiguous. + if (strides.back() > 1) { + int size = 1; + for (int i = x.ndim() - 1; i >= 0; i--) { + if (axes.back() == i) { + continue; + } + if (x.strides()[i] != size) { + break; + } + size *= x.shape(i); + } + if (size >= strides.back()) { + return ReductionPlan(GeneralStridedReduce, shape, strides); + } + } + + return ReductionPlan(GeneralReduce, shape, strides); +} + +} // namespace mlx::core diff --git a/mlx/backend/no_cpu/CMakeLists.txt b/mlx/backend/no_cpu/CMakeLists.txt index c82e2a457..50a30da43 100644 --- a/mlx/backend/no_cpu/CMakeLists.txt +++ b/mlx/backend/no_cpu/CMakeLists.txt @@ -6,5 +6,6 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp )