From 1156c84e86c02d209373c545a47e8da3e1be246b Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 3 Feb 2025 15:58:02 -0800
Subject: [PATCH] Refactor common into cpu specific and truly common (#1817)

* refactor

* fix extension example

* fix no-cpu
---
 examples/extensions/axpby/axpby.cpp | 1 +
 mlx/CMakeLists.txt | 4 +-
 mlx/backend/common/CMakeLists.txt | 88 +--
 mlx/backend/common/binary.h | 366 +-----------
 mlx/backend/common/copy.h | 14 -
 mlx/backend/common/erf.cpp | 40 --
 mlx/backend/common/reduce.cpp | 486 ++++-----------
 mlx/backend/common/reduce.h | 186 ------
 mlx/backend/common/reduce_utils.cpp | 147 -----
 mlx/backend/common/simd/simd.h | 4 -
 mlx/backend/common/simd/type.h | 7 -
 mlx/backend/common/ternary.h | 153 +----
 mlx/backend/cpu/CMakeLists.txt | 81 +++
 mlx/backend/{common => cpu}/arange.h | 0
 mlx/backend/{common => cpu}/arg_reduce.cpp | 2 +-
 mlx/backend/{common => cpu}/binary.cpp | 6 +-
 mlx/backend/cpu/binary.h | 370 ++++++++++++
 mlx/backend/{common => cpu}/binary_ops.h | 2 +-
 mlx/backend/{common => cpu}/binary_two.h | 2 +-
 mlx/backend/{common => cpu}/cholesky.cpp | 4 +-
 .../compiled_cpu.cpp => cpu/compiled.cpp} | 4 +-
 .../{common => cpu}/compiled_preamble.h | 4 +-
 mlx/backend/{common => cpu}/conv.cpp | 4 +-
 mlx/backend/{common => cpu}/copy.cpp | 4 +-
 mlx/backend/cpu/copy.h | 24 +
 mlx/backend/{common => cpu}/eigh.cpp | 4 +-
 mlx/backend/{common => cpu}/fft.cpp | 0
 mlx/backend/{common => cpu}/gemm.h | 0
 mlx/backend/{common => cpu}/gemms/bnns.cpp | 2 +-
 mlx/backend/{common => cpu}/gemms/cblas.cpp | 4 +-
 mlx/backend/{common => cpu}/gemms/no_bf16.cpp | 2 +-
 mlx/backend/{common => cpu}/gemms/no_fp16.cpp | 2 +-
 mlx/backend/{common => cpu}/hadamard.cpp | 2 +-
 mlx/backend/{common => cpu}/indexing.cpp | 2 +-
 mlx/backend/{common => cpu}/inverse.cpp | 4 +-
 mlx/backend/{common => cpu}/jit_compiler.cpp | 2 +-
 mlx/backend/{common => cpu}/jit_compiler.h | 0
 mlx/backend/{common => cpu}/lapack.h | 0
 .../make_compiled_preamble.ps1 | 2 +-
 .../{common => cpu}/make_compiled_preamble.sh | 2 +-
 mlx/backend/{common => cpu}/masked_mm.cpp | 4 +-
 mlx/backend/{common => cpu}/matmul.cpp | 4 +-
 mlx/backend/{common => cpu}/primitives.cpp | 6 +-
 mlx/backend/{common => cpu}/qrf.cpp | 4 +-
 mlx/backend/{common => cpu}/quantized.cpp | 4 +-
 mlx/backend/cpu/reduce.cpp | 552 ++++++++++++++++++
 mlx/backend/{common => cpu}/scan.cpp | 4 +-
 mlx/backend/{common => cpu}/select.cpp | 4 +-
 .../simd/accelerate_fp16_simd.h | 4 +-
 .../{common => cpu}/simd/accelerate_simd.h | 4 +-
 mlx/backend/{common => cpu}/simd/base_simd.h | 0
 mlx/backend/{common => cpu}/simd/math.h | 2 +-
 .../{common => cpu}/simd/neon_fp16_simd.h | 2 +-
 mlx/backend/cpu/simd/simd.h | 4 +
 mlx/backend/cpu/simd/type.h | 7 +
 mlx/backend/cpu/slicing.h | 21 +
 mlx/backend/{common => cpu}/softmax.cpp | 4 +-
 mlx/backend/{common => cpu}/sort.cpp | 2 +-
 mlx/backend/{common => cpu}/svd.cpp | 4 +-
 mlx/backend/cpu/ternary.h | 157 +++++
 mlx/backend/{common => cpu}/threefry.cpp | 2 +-
 mlx/backend/{common => cpu}/threefry.h | 0
 mlx/backend/{common => cpu}/unary.cpp | 4 +-
 mlx/backend/{common => cpu}/unary.h | 6 +-
 mlx/backend/{common => cpu}/unary_ops.h | 2 +-
 mlx/backend/metal/copy.cpp | 1 +
 mlx/backend/metal/fft.cpp | 1 +
 mlx/backend/metal/matmul.cpp | 1 +
 mlx/backend/no_cpu/CMakeLists.txt | 12 +-
 .../compiled.cpp} | 3 +-
 mlx/distributed/mpi/mpi.cpp | 2 +-
 mlx/distributed/ring/ring.cpp | 2 +-
 72 files changed, 1426 insertions(+), 1434 deletions(-)
 delete mode 100644 mlx/backend/common/erf.cpp
 delete mode 100644 mlx/backend/common/reduce_utils.cpp
 delete mode 100644 mlx/backend/common/simd/simd.h
 delete mode 100644 mlx/backend/common/simd/type.h
 create mode 100644 mlx/backend/cpu/CMakeLists.txt
 rename mlx/backend/{common => cpu}/arange.h (100%)
 rename mlx/backend/{common => cpu}/arg_reduce.cpp (98%)
 rename mlx/backend/{common => cpu}/binary.cpp (98%)
 create mode 100644 mlx/backend/cpu/binary.h
 rename mlx/backend/{common => cpu}/binary_ops.h (98%)
 rename mlx/backend/{common => cpu}/binary_two.h (99%)
 rename mlx/backend/{common => cpu}/cholesky.cpp (96%)
 rename mlx/backend/{common/compiled_cpu.cpp => cpu/compiled.cpp} (99%)
 rename mlx/backend/{common => cpu}/compiled_preamble.h (69%)
 rename mlx/backend/{common => cpu}/conv.cpp (99%)
 rename mlx/backend/{common => cpu}/copy.cpp (99%)
 create mode 100644 mlx/backend/cpu/copy.h
 rename mlx/backend/{common => cpu}/eigh.cpp (97%)
 rename mlx/backend/{common => cpu}/fft.cpp (100%)
 rename mlx/backend/{common => cpu}/gemm.h (100%)
 rename mlx/backend/{common => cpu}/gemms/bnns.cpp (99%)
 rename mlx/backend/{common => cpu}/gemms/cblas.cpp (92%)
 rename mlx/backend/{common => cpu}/gemms/no_bf16.cpp (89%)
 rename mlx/backend/{common => cpu}/gemms/no_fp16.cpp (89%)
 rename mlx/backend/{common => cpu}/hadamard.cpp (98%)
 rename mlx/backend/{common => cpu}/indexing.cpp (99%)
 rename mlx/backend/{common => cpu}/inverse.cpp (97%)
 rename mlx/backend/{common => cpu}/jit_compiler.cpp (98%)
 rename mlx/backend/{common => cpu}/jit_compiler.h (100%)
 rename mlx/backend/{common => cpu}/lapack.h (100%)
 rename mlx/backend/{common => cpu}/make_compiled_preamble.ps1 (97%)
 rename mlx/backend/{common => cpu}/make_compiled_preamble.sh (84%)
 rename mlx/backend/{common => cpu}/masked_mm.cpp (99%)
 rename mlx/backend/{common => cpu}/matmul.cpp (96%)
 rename mlx/backend/{common => cpu}/primitives.cpp (99%)
 rename mlx/backend/{common => cpu}/qrf.cpp (98%)
 rename mlx/backend/{common => cpu}/quantized.cpp (99%)
 create mode 100644 mlx/backend/cpu/reduce.cpp
 rename mlx/backend/{common => cpu}/scan.cpp (99%)
 rename mlx/backend/{common => cpu}/select.cpp (95%)
 rename mlx/backend/{common => cpu}/simd/accelerate_fp16_simd.h (94%)
 rename mlx/backend/{common => cpu}/simd/accelerate_simd.h (98%)
 rename mlx/backend/{common => cpu}/simd/base_simd.h (100%)
 rename mlx/backend/{common => cpu}/simd/math.h (99%)
 rename mlx/backend/{common => cpu}/simd/neon_fp16_simd.h (99%)
 create mode 100644 mlx/backend/cpu/simd/simd.h
 create mode 100644 mlx/backend/cpu/simd/type.h
 create mode 100644 mlx/backend/cpu/slicing.h
 rename mlx/backend/{common => cpu}/softmax.cpp (98%)
 rename mlx/backend/{common => cpu}/sort.cpp (99%)
 rename mlx/backend/{common => cpu}/svd.cpp (98%)
 create mode 100644 mlx/backend/cpu/ternary.h
 rename mlx/backend/{common => cpu}/threefry.cpp (95%)
 rename mlx/backend/{common => cpu}/threefry.h (100%)
 rename mlx/backend/{common => cpu}/unary.cpp (98%)
 rename mlx/backend/{common => cpu}/unary.h (97%)
 rename mlx/backend/{common => cpu}/unary_ops.h (98%)
 rename mlx/backend/{common/compiled_nocpu.cpp => no_cpu/compiled.cpp} (91%)

diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp
index 1a5d8c1c9..aa312bb3a 100644
--- a/examples/extensions/axpby/axpby.cpp
+++ b/examples/extensions/axpby/axpby.cpp
@@ -6,6 +6,7 @@
 
 #include "mlx/backend/common/copy.h"
 #include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/copy.h"
 #include "mlx/utils.h"
 
 #include "axpby/axpby.h"
diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt
index b0b7f0bbb..5f5f81c96 100644
--- a/mlx/CMakeLists.txt
+++ b/mlx/CMakeLists.txt
@@ -29,8 +29,10 @@ if(WIN32)
set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE) endif() +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common) + if(MLX_BUILD_CPU) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cpu) else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu) endif() diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 6d61c7ff1..82e6eef84 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -1,88 +1,8 @@ -if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") - set(COMPILER ${CMAKE_C_COMPILER}) - set(CLANG TRUE) -else() - set(COMPILER ${CMAKE_CXX_COMPILER}) -endif() - -set(COMPILE_DEPS - ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h - ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h - ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h - ${PROJECT_SOURCE_DIR}/mlx/types/complex.h - simd/simd.h - simd/base_simd.h - simd/math.h - simd/type.h - unary_ops.h - binary_ops.h) - -if(MSVC) - set(SHELL_EXT ps1) - set(SHELL_CMD powershell -ExecutionPolicy Bypass -File) -else() - set(SHELL_EXT sh) - set(SHELL_CMD bash) -endif() - -add_custom_command( - OUTPUT compiled_preamble.cpp - COMMAND - ${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT} - ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER} - ${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR} - DEPENDS make_compiled_preamble.${SHELL_EXT} compiled_preamble.h - ${COMPILE_DEPS}) - -add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp) - -add_dependencies(mlx cpu_compiled_preamble) - target_sources( mlx - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cblas.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp - ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp) - -if(MLX_BUILD_ACCELERATE) - target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp) -else() - target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp) -endif() - -if(IOS) - target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp) -else() - target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp) -endif() + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp 
+ ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp) diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h index e28db35e1..9a8c10951 100644 --- a/mlx/backend/common/binary.h +++ b/mlx/backend/common/binary.h @@ -1,18 +1,13 @@ // Copyright © 2023 Apple Inc. #pragma once -#include #include "mlx/allocator.h" #include "mlx/array.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/common/simd/simd.h" - namespace mlx::core { -namespace { - enum class BinaryOpType { ScalarScalar, ScalarVector, @@ -21,7 +16,7 @@ enum class BinaryOpType { General, }; -BinaryOpType get_binary_op_type(const array& a, const array& b) { +inline BinaryOpType get_binary_op_type(const array& a, const array& b) { BinaryOpType bopt; if (a.data_size() == 1 && b.data_size() == 1) { bopt = BinaryOpType::ScalarScalar; @@ -39,7 +34,7 @@ BinaryOpType get_binary_op_type(const array& a, const array& b) { return bopt; } -void set_binary_op_output_data( +inline void set_binary_op_output_data( const array& a, const array& b, array& out, @@ -124,361 +119,4 @@ void set_binary_op_output_data( } } -template -struct VectorScalar { - Op op; - - VectorScalar(Op op_) : op(op_) {} - - template - void operator()(const T* a, const T* b, U* dst, int size) { - T scalar = *b; - constexpr int N = simd::max_size; - while (size >= N) { - simd::store(dst, op(simd::load(a), simd::Simd(scalar))); - dst += N; - a += N; - size -= N; - } - while (size-- > 0) { - *dst = op(*a, scalar); - dst++; - a++; - } - } -}; - -template -struct ScalarVector { - Op op; - - ScalarVector(Op op_) : op(op_) {} - - template - void operator()(const T* a, const T* b, U* dst, int size) { - T scalar = *a; - constexpr int N = simd::max_size; - while (size >= N) { - simd::store(dst, op(simd::Simd(scalar), simd::load(b))); - dst += N; - b += N; - size -= N; - } - while (size-- > 0) { - *dst = op(scalar, *b); - dst++; - b++; - } - } -}; - -template -struct VectorVector { - Op op; - - VectorVector(Op op_) : op(op_) {} - - template - void operator()(const T* a, const T* b, U* dst, int size) { - constexpr int N = simd::max_size; - while (size >= N) { - simd::store(dst, op(simd::load(a), simd::load(b))); - dst += N; - a += N; - b += N; - size -= N; - } - while (size-- > 0) { - *dst = op(*a, *b); - dst++; - a++; - b++; - } - } -}; - -template -void binary_op_dims( - const T* a, - const T* b, - U* out, - Op op, - const Shape& shape, - const Strides& a_strides, - const Strides& b_strides, - const Strides& out_strides, - int axis) { - auto stride_a = a_strides[axis]; - auto stride_b = b_strides[axis]; - auto stride_out = out_strides[axis]; - auto N = shape[axis]; - - for (int i = 0; i < N; i++) { - if constexpr (D > 1) { - binary_op_dims( - a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1); - } else { - if constexpr (Strided) { - op(a, b, out, stride_out); - } else { - *out = op(*a, *b); - } - } - out += stride_out; - a += stride_a; - b += stride_b; - } -} - -template -void binary_op_dispatch_dims( - const array& a, - const array& b, - array& out, - Op op, - int dim, - const Shape& shape, - const Strides& a_strides, - const Strides& b_strides, - const Strides& out_strides) { - const T* a_ptr = a.data(); - const T* b_ptr = b.data(); - U* out_ptr = out.data(); - switch (dim) { - case 1: - binary_op_dims( - a_ptr, - b_ptr, - out_ptr, - op, - shape, - a_strides, - b_strides, - out_strides, - 0); - return; - case 2: - binary_op_dims( - a_ptr, - b_ptr, - out_ptr, - op, - shape, - a_strides, - b_strides, - out_strides, - 
0); - return; - case 3: - binary_op_dims( - a_ptr, - b_ptr, - out_ptr, - op, - shape, - a_strides, - b_strides, - out_strides, - 0); - return; - } - - ContiguousIterator a_it(shape, a_strides, dim - 3); - ContiguousIterator b_it(shape, b_strides, dim - 3); - auto stride = out_strides[dim - 4]; - for (int64_t elem = 0; elem < a.size(); elem += stride) { - binary_op_dims( - a_ptr + a_it.loc, - b_ptr + b_it.loc, - out_ptr + elem, - op, - shape, - a_strides, - b_strides, - out_strides, - dim - 3); - a_it.step(); - b_it.step(); - } -} - -template -void binary_op(const array& a, const array& b, array& out, Op op) { - auto bopt = get_binary_op_type(a, b); - set_binary_op_output_data(a, b, out, bopt); - - // The full computation is scalar scalar so call the base op once - if (bopt == BinaryOpType::ScalarScalar) { - *(out.data()) = op(*a.data(), *b.data()); - return; - } - - // The full computation is scalar vector so delegate to the op - if (bopt == BinaryOpType::ScalarVector) { - ScalarVector{op}(a.data(), b.data(), out.data(), b.data_size()); - return; - } - - // The full computation is vector scalar so delegate to the op - if (bopt == BinaryOpType::VectorScalar) { - VectorScalar{op}(a.data(), b.data(), out.data(), a.data_size()); - return; - } - - // The full computation is vector vector so delegate to the op - if (bopt == BinaryOpType::VectorVector) { - VectorVector{op}(a.data(), b.data(), out.data(), out.size()); - return; - } - - // General computation so let's try to optimize - auto [new_shape, new_strides] = collapse_contiguous_dims( - a.shape(), {a.strides(), b.strides(), out.strides()}); - const auto& a_strides = new_strides[0]; - const auto& b_strides = new_strides[1]; - const auto& strides = new_strides[2]; - - // Get the left-most dim such that the array is row contiguous after - auto leftmost_rc_dim = [&strides](const auto& arr_strides) { - int d = arr_strides.size() - 1; - for (; d >= 0 && arr_strides[d] == strides[d]; d--) { - } - return d + 1; - }; - auto a_rc_dim = leftmost_rc_dim(a_strides); - auto b_rc_dim = leftmost_rc_dim(b_strides); - - // Get the left-most dim such that the array is a broadcasted "scalar" after - auto leftmost_s_dim = [](const auto& arr_strides) { - int d = arr_strides.size() - 1; - for (; d >= 0 && arr_strides[d] == 0; d--) { - } - return d + 1; - }; - auto a_s_dim = leftmost_s_dim(a_strides); - auto b_s_dim = leftmost_s_dim(b_strides); - - auto ndim = new_shape.size(); - - // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous - int dim = ndim; - if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) { - bopt = BinaryOpType::VectorVector; - dim = d; - // Case 2: LxM and Fx1 where L and F are broadcastable and M is row - // contiguous - } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) { - bopt = BinaryOpType::VectorScalar; - dim = d; - // Case 3: Lx1 and FxM where L and F are broadcastable and M is row - // contiguous - } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) { - bopt = BinaryOpType::ScalarVector; - dim = d; - } - - // Can be sure dim > 0 since otherwise we would have used one of the fully - // contiguous methods above. Except for the case that the flags do not - // correspond to the underlying contiguity. 
- if (dim == 0 || strides[dim - 1] < 16) { - bopt = BinaryOpType::General; - dim = ndim; - } - - switch (bopt) { - case BinaryOpType::VectorVector: - binary_op_dispatch_dims( - a, - b, - out, - VectorVector{op}, - dim, - new_shape, - a_strides, - b_strides, - strides); - break; - case BinaryOpType::VectorScalar: - binary_op_dispatch_dims( - a, - b, - out, - VectorScalar{op}, - dim, - new_shape, - a_strides, - b_strides, - strides); - break; - case BinaryOpType::ScalarVector: - binary_op_dispatch_dims( - a, - b, - out, - ScalarVector{op}, - dim, - new_shape, - a_strides, - b_strides, - strides); - break; - default: - binary_op_dispatch_dims( - a, b, out, op, dim, new_shape, a_strides, b_strides, strides); - break; - } -} - -template -void binary_op(const array& a, const array& b, array& out, Op op) { - binary_op(a, b, out, op); -} - -template -void binary(const array& a, const array& b, array& out, Op op) { - switch (out.dtype()) { - case bool_: - binary_op(a, b, out, op); - break; - case uint8: - binary_op(a, b, out, op); - break; - case uint16: - binary_op(a, b, out, op); - break; - case uint32: - binary_op(a, b, out, op); - break; - case uint64: - binary_op(a, b, out, op); - break; - case int8: - binary_op(a, b, out, op); - break; - case int16: - binary_op(a, b, out, op); - break; - case int32: - binary_op(a, b, out, op); - break; - case int64: - binary_op(a, b, out, op); - break; - case float16: - binary_op(a, b, out, op); - break; - case float32: - binary_op(a, b, out, op); - break; - case bfloat16: - binary_op(a, b, out, op); - break; - case complex64: - binary_op(a, b, out, op); - break; - } -} - -} // namespace - } // namespace mlx::core diff --git a/mlx/backend/common/copy.h b/mlx/backend/common/copy.h index 351790c02..b05967638 100644 --- a/mlx/backend/common/copy.h +++ b/mlx/backend/common/copy.h @@ -3,7 +3,6 @@ #pragma once #include "mlx/array.h" -#include "mlx/backend/common/utils.h" namespace mlx::core { @@ -23,17 +22,4 @@ enum class CopyType { GeneralGeneral }; -void copy(const array& src, array& dst, CopyType ctype); -void copy_inplace(const array& src, array& dst, CopyType ctype); - -void copy_inplace( - const array& src, - array& dst, - const Shape& data_shape, - const Strides& i_strides, - const Strides& o_strides, - int64_t i_offset, - int64_t o_offset, - CopyType ctype); - } // namespace mlx::core diff --git a/mlx/backend/common/erf.cpp b/mlx/backend/common/erf.cpp deleted file mode 100644 index 83769d078..000000000 --- a/mlx/backend/common/erf.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#include - -namespace mlx::core { - -/* Approximation to the inverse error function. 
- * Based on code from: - * https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348 - */ -float erfinv(float a) { - auto t = std::fma(a, 0.0f - a, 1.0f); - t = std::log(t); - float p; - if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793 - p = 3.03697567e-10f; // 0x1.4deb44p-32 - p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 - p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 - p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 - p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 - p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 - p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 - p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 - p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 - } else { // maximum ulp error = 2.35002 - p = 5.43877832e-9f; // 0x1.75c000p-28 - p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 - p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 - p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 - p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 - p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 - p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 - p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 - p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 - p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 - } - return a * p; -} - -} // namespace mlx::core diff --git a/mlx/backend/common/reduce.cpp b/mlx/backend/common/reduce.cpp index 71c72e2ea..5c7f63b75 100644 --- a/mlx/backend/common/reduce.cpp +++ b/mlx/backend/common/reduce.cpp @@ -1,377 +1,147 @@ -// Copyright © 2023 Apple Inc. - -#include -#include -#include +// Copyright © 2024 Apple Inc. #include "mlx/backend/common/reduce.h" -#include "mlx/backend/common/simd/simd.h" -#include "mlx/primitives.h" namespace mlx::core { -namespace { - -template -struct Limits { - static const U max; - static const U min; -}; - -#define instantiate_default_limit(type) \ - template <> \ - struct Limits { \ - static constexpr type max = std::numeric_limits::max(); \ - static constexpr type min = std::numeric_limits::min(); \ - }; - -instantiate_default_limit(uint8_t); -instantiate_default_limit(uint16_t); -instantiate_default_limit(uint32_t); -instantiate_default_limit(uint64_t); -instantiate_default_limit(int8_t); -instantiate_default_limit(int16_t); -instantiate_default_limit(int32_t); -instantiate_default_limit(int64_t); - -#define instantiate_float_limit(type) \ - template <> \ - struct Limits { \ - static const type max; \ - static const type min; \ - }; - -instantiate_float_limit(float16_t); -instantiate_float_limit(bfloat16_t); -instantiate_float_limit(float); -instantiate_float_limit(complex64_t); - -template <> -struct Limits { - static constexpr bool max = true; - static constexpr bool min = false; -}; - -const float Limits::max = std::numeric_limits::infinity(); -const float Limits::min = -std::numeric_limits::infinity(); -const bfloat16_t Limits::max = - std::numeric_limits::infinity(); -const bfloat16_t Limits::min = - -std::numeric_limits::infinity(); -const float16_t Limits::max = std::numeric_limits::infinity(); -const float16_t Limits::min = - -std::numeric_limits::infinity(); -const complex64_t Limits::max = - std::numeric_limits::infinity(); -const complex64_t Limits::min = - -std::numeric_limits::infinity(); - -struct AndReduce { - template - bool operator()(bool x, T y) { - return x & (y != 0); - } - - bool operator()(bool x, bool y) { - return x & y; - } - - template - simd::Simd operator()(simd::Simd y, 
simd::Simd x) { - return x & (y != 0); - }; - - template - simd::Simd operator()(simd::Simd y, simd::Simd x) { - return x & y; - }; - - template - bool operator()(simd::Simd x) { - return simd::all(x); - }; -}; - -struct OrReduce { - template - bool operator()(bool x, T y) { - return x | (y != 0); - } - - bool operator()(bool x, bool y) { - return x | y; - } - - template - simd::Simd operator()(simd::Simd y, simd::Simd x) { - return x | (y != 0); - }; - - template - simd::Simd operator()(simd::Simd y, simd::Simd x) { - return x | y; - }; - - template - bool operator()(simd::Simd x) { - return simd::any(x); - }; -}; - -struct MaxReduce { - template - T operator()(T y, T x) { - return (*this)(simd::Simd(x), simd::Simd(y)).value; - }; - - template - simd::Simd operator()(simd::Simd y, simd::Simd x) { - return simd::maximum(x, y); - }; - - template - T operator()(simd::Simd x) { - return simd::max(x); - }; -}; - -struct MinReduce { - template - T operator()(T y, T x) { - return (*this)(simd::Simd(x), simd::Simd(y)).value; - }; - - template - simd::Simd operator()(simd::Simd y, simd::Simd x) { - return simd::minimum(x, y); - }; - - template - T operator()(simd::Simd x) { - return simd::min(x); - }; -}; - -struct SumReduce { - template - U operator()(U y, T x) { - return x + y; - }; - - template - simd::Simd operator()(simd::Simd y, simd::Simd x) { - return y + x; - }; - - template - T operator()(simd::Simd x) { - return simd::sum(x); - }; -}; - -struct ProdReduce { - template - U operator()(U y, T x) { - return x * y; - }; - - template - simd::Simd operator()(simd::Simd y, simd::Simd x) { - return x * y; - }; - - template - T operator()(simd::Simd x) { - return simd::prod(x); - }; -}; - -template -void reduce_dispatch_and_or( - const array& in, - array& out, - Reduce::ReduceType rtype, +std::pair shapes_without_reduction_axes( + const array& x, const std::vector& axes) { - if (rtype == Reduce::And) { - reduction_op(in, out, axes, true, AndReduce()); - } else { - reduction_op(in, out, axes, false, OrReduce()); + auto shape = x.shape(); + auto strides = x.strides(); + + for (int i = axes.size() - 1; i >= 0; i--) { + int a = axes[i]; + shape.erase(shape.begin() + a); + strides.erase(strides.begin() + a); } + + return std::make_pair(shape, strides); } -template -void reduce_dispatch_sum_prod( - const array& in, - array& out, - Reduce::ReduceType rtype, - const std::vector& axes) { - if (rtype == Reduce::Sum) { - if constexpr (std::is_integral_v && sizeof(InT) <= 4) { - reduction_op(in, out, axes, 0, SumReduce()); - } else { - reduction_op(in, out, axes, 0, SumReduce()); +ReductionPlan get_reduction_plan(const array& x, const std::vector& axes) { + // The data is all there and we are reducing over everything + if (x.size() == x.data_size() && axes.size() == x.ndim() && + x.flags().contiguous) { + return ContiguousAllReduce; + } + + // Row contiguous input so the output is row contiguous + if (x.flags().row_contiguous) { + // Merge consecutive axes + Shape shape = {x.shape(axes[0])}; + Strides strides = {x.strides()[axes[0]]}; + for (int i = 1; i < axes.size(); i++) { + if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) { + shape.back() *= x.shape(axes[i]); + strides.back() = x.strides()[axes[i]]; + } else { + shape.push_back(x.shape(axes[i])); + strides.push_back(x.strides()[axes[i]]); + } } - } else { - if constexpr (std::is_integral_v && sizeof(InT) <= 4) { - reduction_op(in, out, axes, 1, ProdReduce()); - } else { - reduction_op(in, out, axes, 1, ProdReduce()); + + // Remove singleton axes 
from the plan + for (int i = shape.size() - 1; i >= 0; i--) { + if (shape[i] == 1) { + shape.erase(shape.begin() + i); + strides.erase(strides.begin() + i); + } + } + + if (strides.back() == 1) { + return ReductionPlan(ContiguousReduce, shape, strides); + } else if (strides.back() > 1) { + return ReductionPlan(ContiguousStridedReduce, shape, strides); } } -} -template -void reduce_dispatch_min_max( - const array& in, - array& out, - Reduce::ReduceType rtype, - const std::vector& axes) { - if (rtype == Reduce::Max) { - auto init = Limits::min; - reduction_op(in, out, axes, init, MaxReduce()); - } else { - auto init = Limits::max; - reduction_op(in, out, axes, init, MinReduce()); - } -} + // Let's check if we can optimize our access patterns + // + // 1. We have a reduction axis with stride 1. Simply call + // GeneralContiguousReduce and be done with it. + // 2. We have transpositions and we are not reducing over the axis with + // stride 1. However, we are reducing over an axis where everything is + // contiguous in memory to the right of that axis. We can call strided + // reduce and be done with it. + // 2. We have weird transpositions and expands. Copy the strides to the + // output, then call strided reduce. -} // namespace - -void nd_loop( - std::function callback, - const Shape& shape, - const Strides& strides) { - std::function loop_inner; - loop_inner = [&](int dim, int offset) { - if (dim < shape.size() - 1) { - auto size = shape[dim]; - auto stride = strides[dim]; - for (int i = 0; i < size; i++) { - loop_inner(dim + 1, offset + i * stride); - } - } else { - auto size = shape[dim]; - auto stride = strides[dim]; - for (int i = 0; i < size; i++) { - callback(offset + i * stride); - } - } - }; - loop_inner(0, 0); -} - -void Reduce::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - switch (reduce_type_) { - case Reduce::And: - case Reduce::Or: { - switch (in.dtype()) { - case bool_: - case uint8: - case int8: - reduce_dispatch_and_or(in, out, reduce_type_, axes_); - break; - case int16: - case uint16: - case float16: - case bfloat16: - reduce_dispatch_and_or(in, out, reduce_type_, axes_); - break; - case uint32: - case int32: - case float32: - reduce_dispatch_and_or(in, out, reduce_type_, axes_); - break; - case uint64: - case int64: - case complex64: - reduce_dispatch_and_or(in, out, reduce_type_, axes_); - break; - } - break; - } - case Reduce::Sum: - case Reduce::Prod: { - switch (in.dtype()) { - case bool_: - case uint8: - case int8: - reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); - break; - case int16: - case uint16: - reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); - break; - case int32: - case uint32: - reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); - break; - case int64: - case uint64: - reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); - break; - case float16: - reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); - break; - case bfloat16: - reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); - break; - case float32: - reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); - break; - case complex64: - reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); - break; - } - break; - } - case Reduce::Max: - case Reduce::Min: { - switch (in.dtype()) { - case bool_: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); - break; - case uint8: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); - break; - case uint16: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); - 
break; - case uint32: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); - break; - case uint64: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); - break; - case int8: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); - break; - case int16: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); - break; - case int32: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); - break; - case int64: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); - break; - case float16: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); - break; - case float32: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); - break; - case bfloat16: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); - break; - case complex64: - reduce_dispatch_min_max(in, out, reduce_type_, axes_); - break; - } - break; + // Sort reduction axes by stride in order to merge them and figure out if we + // have a contiguous reduction. + std::vector> reductions; + for (auto a : axes) { + if (x.shape(a) > 1) { + reductions.push_back(std::make_pair(x.shape(a), x.strides()[a])); } } + std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) { + bool a_is_zero = a.second == 0; + bool b_is_zero = b.second == 0; + return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second; + }); + // Extract the two smallest and try to merge them in case the contiguous + // reduction can be bigger than just the last axis. + for (int i = reductions.size() - 1; i >= 1; i--) { + auto a = reductions[i]; + auto b = reductions[i - 1]; + + // b.stride = a.shape * a.stride then a and b are contiguous + if (b.second == a.first * a.second) { + reductions.erase(reductions.begin() + i); + reductions[i - 1] = std::make_pair(a.first * b.first, a.second); + } + } + + Shape shape; + Strides strides; + for (auto r : reductions) { + shape.push_back(r.first); + strides.push_back(r.second); + } + + // We can call the contiguous reduction op for every weird way the input is + // structured in the rest of the axes. + if (strides.back() == 1) { + return ReductionPlan(GeneralContiguousReduce, shape, strides); + } + + // Delegate to the general strided reduction op if the axes after + // strides.back() are contiguous. + if (strides.back() > 1) { + int64_t size = 1; + bool have_expand = false; + for (int i = x.ndim() - 1; i >= 0; i--) { + if (axes.back() == i) { + continue; + } + + auto stride_i = x.strides()[i]; + auto shape_i = x.shape(i); + if (stride_i == 0) { + if (shape_i == 1) { + continue; + } + + have_expand = true; + break; + } + + if (stride_i != size && shape_i != 1) { + break; + } + size *= shape_i; + } + // In the case of an expanded dimension we are being conservative and + // require the smallest reduction stride to be smaller than the maximum row + // contiguous size. The reason is that we can't easily know if the reduced + // axis is before or after an expanded dimension. 
+ if (size > strides.back() || (size == strides.back() && !have_expand)) { + return ReductionPlan(GeneralStridedReduce, shape, strides); + } + } + + return ReductionPlan(GeneralReduce, shape, strides); } } // namespace mlx::core diff --git a/mlx/backend/common/reduce.h b/mlx/backend/common/reduce.h index b9e44ddc8..ddb5c3492 100644 --- a/mlx/backend/common/reduce.h +++ b/mlx/backend/common/reduce.h @@ -2,7 +2,6 @@ #pragma once -#include "mlx/backend/common/simd/simd.h" #include "mlx/backend/common/utils.h" namespace mlx::core { @@ -49,193 +48,8 @@ struct ReductionPlan { ReductionPlan get_reduction_plan(const array& x, const std::vector& axes); -// Helper for the ndimensional strided loop -// Should this be in utils? -void nd_loop( - std::function callback, - const Shape& shape, - const Strides& strides); - std::pair shapes_without_reduction_axes( const array& x, const std::vector& axes); -template -void strided_reduce( - const T* x, - U* accumulator, - int size, - size_t stride, - Op op) { - constexpr int N = std::min(simd::max_size, simd::max_size); - for (int i = 0; i < size; i++) { - U* moving_accumulator = accumulator; - auto s = stride; - while (s >= N) { - auto acc = simd::load(moving_accumulator); - auto v = simd::Simd(simd::load(x)); - simd::store(moving_accumulator, op(acc, v)); - moving_accumulator += N; - x += N; - s -= N; - } - while (s-- > 0) { - *moving_accumulator = op(*moving_accumulator, *x); - moving_accumulator++; - x++; - } - } -}; - -template -void contiguous_reduce(const T* x, U* accumulator, int size, Op op, U init) { - constexpr int N = std::min(simd::max_size, simd::max_size); - simd::Simd accumulator_v(init); - while (size >= N) { - accumulator_v = op(accumulator_v, simd::Simd(simd::load(x))); - x += N; - size -= N; - } - *accumulator = op(*accumulator, op(accumulator_v)); - while (size-- > 0) { - *accumulator = op(*accumulator, *x); - x++; - } -} - -template -void reduction_op( - const array& x, - array& out, - const std::vector& axes, - U init, - Op op) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); - ReductionPlan plan = get_reduction_plan(x, axes); - - if (plan.type == ContiguousAllReduce) { - U* out_ptr = out.data(); - *out_ptr = init; - contiguous_reduce(x.data(), out_ptr, x.size(), op, init); - return; - } - - if (plan.type == ContiguousReduce && plan.shape.size() == 1) { - int reduction_size = plan.shape[0]; - const T* x_ptr = x.data(); - U* out_ptr = out.data(); - for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) { - *out_ptr = init; - contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init); - } - return; - } - - if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) { - int reduction_size = plan.shape.back(); - plan.shape.pop_back(); - plan.strides.pop_back(); - const T* x_ptr = x.data(); - U* out_ptr = out.data(); - // Unrolling the following loop (and implementing it in order for - // ContiguousReduce) should hold extra performance boost. 
- auto [shape, strides] = shapes_without_reduction_axes(x, axes); - if (plan.shape.size() == 0) { - for (int i = 0; i < out.size(); i++, out_ptr++) { - int offset = elem_to_loc(i, shape, strides); - *out_ptr = init; - contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init); - } - } else { - for (int i = 0; i < out.size(); i++, out_ptr++) { - int offset = elem_to_loc(i, shape, strides); - *out_ptr = init; - nd_loop( - [&](int extra_offset) { - contiguous_reduce( - x_ptr + offset + extra_offset, - out_ptr, - reduction_size, - op, - init); - }, - plan.shape, - plan.strides); - } - } - return; - } - - if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) { - int reduction_size = plan.shape.back(); - size_t reduction_stride = plan.strides.back(); - plan.shape.pop_back(); - plan.strides.pop_back(); - const T* x_ptr = x.data(); - U* out_ptr = out.data(); - for (int i = 0; i < out.size(); i += reduction_stride) { - std::fill_n(out_ptr, reduction_stride, init); - strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op); - x_ptr += reduction_stride * reduction_size; - out_ptr += reduction_stride; - } - return; - } - - if (plan.type == GeneralStridedReduce || - plan.type == ContiguousStridedReduce) { - int reduction_size = plan.shape.back(); - size_t reduction_stride = plan.strides.back(); - plan.shape.pop_back(); - plan.strides.pop_back(); - const T* x_ptr = x.data(); - U* out_ptr = out.data(); - auto [shape, strides] = shapes_without_reduction_axes(x, axes); - if (plan.shape.size() == 0) { - for (int i = 0; i < out.size(); i += reduction_stride) { - int offset = elem_to_loc(i, shape, strides); - std::fill_n(out_ptr, reduction_stride, init); - strided_reduce( - x_ptr + offset, out_ptr, reduction_size, reduction_stride, op); - out_ptr += reduction_stride; - } - } else { - for (int i = 0; i < out.size(); i += reduction_stride) { - int offset = elem_to_loc(i, shape, strides); - std::fill_n(out_ptr, reduction_stride, init); - nd_loop( - [&](int extra_offset) { - strided_reduce( - x_ptr + offset + extra_offset, - out_ptr, - reduction_size, - reduction_stride, - op); - }, - plan.shape, - plan.strides); - out_ptr += reduction_stride; - } - } - return; - } - - if (plan.type == GeneralReduce) { - const T* x_ptr = x.data(); - U* out_ptr = out.data(); - auto [shape, strides] = shapes_without_reduction_axes(x, axes); - for (int i = 0; i < out.size(); i++, out_ptr++) { - int offset = elem_to_loc(i, shape, strides); - U val = init; - nd_loop( - [&](int extra_offset) { - val = op(val, *(x_ptr + offset + extra_offset)); - }, - plan.shape, - plan.strides); - *out_ptr = val; - } - } -} - } // namespace mlx::core diff --git a/mlx/backend/common/reduce_utils.cpp b/mlx/backend/common/reduce_utils.cpp deleted file mode 100644 index 5c7f63b75..000000000 --- a/mlx/backend/common/reduce_utils.cpp +++ /dev/null @@ -1,147 +0,0 @@ -// Copyright © 2024 Apple Inc. 
- -#include "mlx/backend/common/reduce.h" - -namespace mlx::core { - -std::pair shapes_without_reduction_axes( - const array& x, - const std::vector& axes) { - auto shape = x.shape(); - auto strides = x.strides(); - - for (int i = axes.size() - 1; i >= 0; i--) { - int a = axes[i]; - shape.erase(shape.begin() + a); - strides.erase(strides.begin() + a); - } - - return std::make_pair(shape, strides); -} - -ReductionPlan get_reduction_plan(const array& x, const std::vector& axes) { - // The data is all there and we are reducing over everything - if (x.size() == x.data_size() && axes.size() == x.ndim() && - x.flags().contiguous) { - return ContiguousAllReduce; - } - - // Row contiguous input so the output is row contiguous - if (x.flags().row_contiguous) { - // Merge consecutive axes - Shape shape = {x.shape(axes[0])}; - Strides strides = {x.strides()[axes[0]]}; - for (int i = 1; i < axes.size(); i++) { - if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) { - shape.back() *= x.shape(axes[i]); - strides.back() = x.strides()[axes[i]]; - } else { - shape.push_back(x.shape(axes[i])); - strides.push_back(x.strides()[axes[i]]); - } - } - - // Remove singleton axes from the plan - for (int i = shape.size() - 1; i >= 0; i--) { - if (shape[i] == 1) { - shape.erase(shape.begin() + i); - strides.erase(strides.begin() + i); - } - } - - if (strides.back() == 1) { - return ReductionPlan(ContiguousReduce, shape, strides); - } else if (strides.back() > 1) { - return ReductionPlan(ContiguousStridedReduce, shape, strides); - } - } - - // Let's check if we can optimize our access patterns - // - // 1. We have a reduction axis with stride 1. Simply call - // GeneralContiguousReduce and be done with it. - // 2. We have transpositions and we are not reducing over the axis with - // stride 1. However, we are reducing over an axis where everything is - // contiguous in memory to the right of that axis. We can call strided - // reduce and be done with it. - // 2. We have weird transpositions and expands. Copy the strides to the - // output, then call strided reduce. - - // Sort reduction axes by stride in order to merge them and figure out if we - // have a contiguous reduction. - std::vector> reductions; - for (auto a : axes) { - if (x.shape(a) > 1) { - reductions.push_back(std::make_pair(x.shape(a), x.strides()[a])); - } - } - std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) { - bool a_is_zero = a.second == 0; - bool b_is_zero = b.second == 0; - return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second; - }); - // Extract the two smallest and try to merge them in case the contiguous - // reduction can be bigger than just the last axis. - for (int i = reductions.size() - 1; i >= 1; i--) { - auto a = reductions[i]; - auto b = reductions[i - 1]; - - // b.stride = a.shape * a.stride then a and b are contiguous - if (b.second == a.first * a.second) { - reductions.erase(reductions.begin() + i); - reductions[i - 1] = std::make_pair(a.first * b.first, a.second); - } - } - - Shape shape; - Strides strides; - for (auto r : reductions) { - shape.push_back(r.first); - strides.push_back(r.second); - } - - // We can call the contiguous reduction op for every weird way the input is - // structured in the rest of the axes. - if (strides.back() == 1) { - return ReductionPlan(GeneralContiguousReduce, shape, strides); - } - - // Delegate to the general strided reduction op if the axes after - // strides.back() are contiguous. 
- if (strides.back() > 1) { - int64_t size = 1; - bool have_expand = false; - for (int i = x.ndim() - 1; i >= 0; i--) { - if (axes.back() == i) { - continue; - } - - auto stride_i = x.strides()[i]; - auto shape_i = x.shape(i); - if (stride_i == 0) { - if (shape_i == 1) { - continue; - } - - have_expand = true; - break; - } - - if (stride_i != size && shape_i != 1) { - break; - } - size *= shape_i; - } - // In the case of an expanded dimension we are being conservative and - // require the smallest reduction stride to be smaller than the maximum row - // contiguous size. The reason is that we can't easily know if the reduced - // axis is before or after an expanded dimension. - if (size > strides.back() || (size == strides.back() && !have_expand)) { - return ReductionPlan(GeneralStridedReduce, shape, strides); - } - } - - return ReductionPlan(GeneralReduce, shape, strides); -} - -} // namespace mlx::core diff --git a/mlx/backend/common/simd/simd.h b/mlx/backend/common/simd/simd.h deleted file mode 100644 index 4b356a9e5..000000000 --- a/mlx/backend/common/simd/simd.h +++ /dev/null @@ -1,4 +0,0 @@ -#pragma once - -#include "mlx/backend/common/simd/math.h" -#include "mlx/backend/common/simd/type.h" diff --git a/mlx/backend/common/simd/type.h b/mlx/backend/common/simd/type.h deleted file mode 100644 index 23b71a1cf..000000000 --- a/mlx/backend/common/simd/type.h +++ /dev/null @@ -1,7 +0,0 @@ -#pragma once - -#include "mlx/backend/common/simd/base_simd.h" - -#ifdef MLX_USE_ACCELERATE -#include "mlx/backend/common/simd/accelerate_simd.h" -#endif diff --git a/mlx/backend/common/ternary.h b/mlx/backend/common/ternary.h index eaff8db00..fad7fa95d 100644 --- a/mlx/backend/common/ternary.h +++ b/mlx/backend/common/ternary.h @@ -7,8 +7,6 @@ namespace mlx::core { -namespace { - // TODO: Add support for more combinations of input types. 
enum class TernaryOpType { ScalarScalarScalar, @@ -16,7 +14,7 @@ enum class TernaryOpType { General, }; -TernaryOpType +inline TernaryOpType get_ternary_op_type(const array& a, const array& b, const array& c) { TernaryOpType topt; if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) { @@ -33,7 +31,7 @@ get_ternary_op_type(const array& a, const array& b, const array& c) { return topt; } -void set_ternary_op_output_data( +inline void set_ternary_op_output_data( const array& a, const array& b, const array& c, @@ -76,152 +74,5 @@ void set_ternary_op_output_data( break; } } -template -void ternary_op_dims( - const T1* a, - const T2* b, - const T3* c, - U* out, - Op op, - const Shape& shape, - const Strides& a_strides, - const Strides& b_strides, - const Strides& c_strides, - const Strides& out_strides, - int axis) { - auto stride_a = a_strides[axis]; - auto stride_b = b_strides[axis]; - auto stride_c = c_strides[axis]; - auto stride_out = out_strides[axis]; - auto N = shape[axis]; - - for (int i = 0; i < N; i++) { - if constexpr (D > 1) { - ternary_op_dims( - a, - b, - c, - out, - op, - shape, - a_strides, - b_strides, - c_strides, - out_strides, - axis + 1); - } else { - *out = op(*a, *b, *c); - } - a += stride_a; - b += stride_b; - c += stride_c; - out += stride_out; - } -} - -template -void ternary_op_dispatch_dims( - const array& a, - const array& b, - const array& c, - array& out, - Op op) { - auto [shape, strides] = collapse_contiguous_dims( - a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()}); - const auto& a_strides = strides[0]; - const auto& b_strides = strides[1]; - const auto& c_strides = strides[2]; - const auto& out_strides = strides[3]; - - const T1* a_ptr = a.data(); - const T2* b_ptr = b.data(); - const T3* c_ptr = c.data(); - U* out_ptr = out.data(); - int ndim = shape.size(); - switch (ndim) { - case 1: - ternary_op_dims( - a_ptr, - b_ptr, - c_ptr, - out_ptr, - op, - shape, - a_strides, - b_strides, - c_strides, - out_strides, - 0); - return; - case 2: - ternary_op_dims( - a_ptr, - b_ptr, - c_ptr, - out_ptr, - op, - shape, - a_strides, - b_strides, - c_strides, - out_strides, - 0); - return; - } - - ContiguousIterator a_it(shape, a_strides, ndim - 2); - ContiguousIterator b_it(shape, b_strides, ndim - 2); - ContiguousIterator c_it(shape, c_strides, ndim - 2); - auto stride = out_strides[ndim - 3]; - for (size_t elem = 0; elem < a.size(); elem += stride) { - ternary_op_dims( - a_ptr + a_it.loc, - b_ptr + b_it.loc, - c_ptr + c_it.loc, - out_ptr + elem, - op, - shape, - a_strides, - b_strides, - c_strides, - out_strides, - ndim - 2); - a_it.step(); - b_it.step(); - c_it.step(); - } -} - -template -void ternary_op( - const array& a, - const array& b, - const array& c, - array& out, - Op op) { - TernaryOpType topt = get_ternary_op_type(a, b, c); - set_ternary_op_output_data(a, b, c, out, topt); - - // The full computation is scalar-scalar-scalar so we call the base op once. 
- if (topt == TernaryOpType::ScalarScalarScalar) { - *(out.data()) = op(*a.data(), *b.data(), *c.data()); - } else if (topt == TernaryOpType::VectorVectorVector) { - const T1* a_ptr = a.data(); - const T2* b_ptr = b.data(); - const T3* c_ptr = c.data(); - U* out_ptr = out.data(); - for (size_t i = 0; i < out.size(); ++i) { - *out_ptr = op(*a_ptr, *b_ptr, *c_ptr); - a_ptr++; - b_ptr++; - c_ptr++; - out_ptr++; - } - } else { - ternary_op_dispatch_dims(a, b, c, out, op); - } -} - -} // namespace } // namespace mlx::core diff --git a/mlx/backend/cpu/CMakeLists.txt b/mlx/backend/cpu/CMakeLists.txt new file mode 100644 index 000000000..b98f3985c --- /dev/null +++ b/mlx/backend/cpu/CMakeLists.txt @@ -0,0 +1,81 @@ +if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + set(COMPILER ${CMAKE_C_COMPILER}) + set(CLANG TRUE) +else() + set(COMPILER ${CMAKE_CXX_COMPILER}) +endif() + +set(COMPILE_DEPS + ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h + ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h + ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h + ${PROJECT_SOURCE_DIR}/mlx/types/complex.h + simd/simd.h + simd/base_simd.h + simd/math.h + simd/type.h + unary_ops.h + binary_ops.h) + +if(MSVC) + set(SHELL_EXT ps1) + set(SHELL_CMD powershell -ExecutionPolicy Bypass -File) +else() + set(SHELL_EXT sh) + set(SHELL_CMD bash) +endif() + +add_custom_command( + OUTPUT compiled_preamble.cpp + COMMAND + ${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT} + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER} + ${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR} + DEPENDS make_compiled_preamble.${SHELL_EXT} compiled_preamble.h + ${COMPILE_DEPS}) + +add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp) + +add_dependencies(mlx cpu_compiled_preamble) + +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cblas.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp) + +if(MLX_BUILD_ACCELERATE) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp) +else() + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp) +endif() + +if(IOS) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../no_cpu/compiled.cpp) +else() + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp) +endif() diff --git a/mlx/backend/common/arange.h b/mlx/backend/cpu/arange.h similarity index 100% rename from mlx/backend/common/arange.h rename to mlx/backend/cpu/arange.h diff --git a/mlx/backend/common/arg_reduce.cpp 
b/mlx/backend/cpu/arg_reduce.cpp similarity index 98% rename from mlx/backend/common/arg_reduce.cpp rename to mlx/backend/cpu/arg_reduce.cpp index 4d66796e1..38eff29a1 100644 --- a/mlx/backend/common/arg_reduce.cpp +++ b/mlx/backend/cpu/arg_reduce.cpp @@ -2,8 +2,8 @@ #include +#include "mlx/backend/common/utils.h" #include "mlx/primitives.h" -#include "utils.h" namespace mlx::core { diff --git a/mlx/backend/common/binary.cpp b/mlx/backend/cpu/binary.cpp similarity index 98% rename from mlx/backend/common/binary.cpp rename to mlx/backend/cpu/binary.cpp index 6178328d1..3fd0e63a5 100644 --- a/mlx/backend/common/binary.cpp +++ b/mlx/backend/cpu/binary.cpp @@ -5,9 +5,9 @@ #include #include "mlx/allocator.h" -#include "mlx/backend/common/binary.h" -#include "mlx/backend/common/binary_ops.h" -#include "mlx/backend/common/binary_two.h" +#include "mlx/backend/cpu/binary.h" +#include "mlx/backend/cpu/binary_ops.h" +#include "mlx/backend/cpu/binary_two.h" #include "mlx/primitives.h" #include "mlx/utils.h" diff --git a/mlx/backend/cpu/binary.h b/mlx/backend/cpu/binary.h new file mode 100644 index 000000000..ab9fb2486 --- /dev/null +++ b/mlx/backend/cpu/binary.h @@ -0,0 +1,370 @@ +// Copyright © 2023 Apple Inc. + +#pragma once +#include + +#include "mlx/allocator.h" +#include "mlx/array.h" +#include "mlx/backend/common/binary.h" +#include "mlx/backend/common/utils.h" + +#include "mlx/backend/cpu/simd/simd.h" + +namespace mlx::core { + +template +struct VectorScalar { + Op op; + + VectorScalar(Op op_) : op(op_) {} + + template + void operator()(const T* a, const T* b, U* dst, int size) { + T scalar = *b; + constexpr int N = simd::max_size; + while (size >= N) { + simd::store(dst, op(simd::load(a), simd::Simd(scalar))); + dst += N; + a += N; + size -= N; + } + while (size-- > 0) { + *dst = op(*a, scalar); + dst++; + a++; + } + } +}; + +template +struct ScalarVector { + Op op; + + ScalarVector(Op op_) : op(op_) {} + + template + void operator()(const T* a, const T* b, U* dst, int size) { + T scalar = *a; + constexpr int N = simd::max_size; + while (size >= N) { + simd::store(dst, op(simd::Simd(scalar), simd::load(b))); + dst += N; + b += N; + size -= N; + } + while (size-- > 0) { + *dst = op(scalar, *b); + dst++; + b++; + } + } +}; + +template +struct VectorVector { + Op op; + + VectorVector(Op op_) : op(op_) {} + + template + void operator()(const T* a, const T* b, U* dst, int size) { + constexpr int N = simd::max_size; + while (size >= N) { + simd::store(dst, op(simd::load(a), simd::load(b))); + dst += N; + a += N; + b += N; + size -= N; + } + while (size-- > 0) { + *dst = op(*a, *b); + dst++; + a++; + b++; + } + } +}; + +template +void binary_op_dims( + const T* a, + const T* b, + U* out, + Op op, + const Shape& shape, + const Strides& a_strides, + const Strides& b_strides, + const Strides& out_strides, + int axis) { + auto stride_a = a_strides[axis]; + auto stride_b = b_strides[axis]; + auto stride_out = out_strides[axis]; + auto N = shape[axis]; + + for (int i = 0; i < N; i++) { + if constexpr (D > 1) { + binary_op_dims( + a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1); + } else { + if constexpr (Strided) { + op(a, b, out, stride_out); + } else { + *out = op(*a, *b); + } + } + out += stride_out; + a += stride_a; + b += stride_b; + } +} + +template +void binary_op_dispatch_dims( + const array& a, + const array& b, + array& out, + Op op, + int dim, + const Shape& shape, + const Strides& a_strides, + const Strides& b_strides, + const Strides& out_strides) { + const T* a_ptr = 
a.data(); + const T* b_ptr = b.data(); + U* out_ptr = out.data(); + switch (dim) { + case 1: + binary_op_dims( + a_ptr, + b_ptr, + out_ptr, + op, + shape, + a_strides, + b_strides, + out_strides, + 0); + return; + case 2: + binary_op_dims( + a_ptr, + b_ptr, + out_ptr, + op, + shape, + a_strides, + b_strides, + out_strides, + 0); + return; + case 3: + binary_op_dims( + a_ptr, + b_ptr, + out_ptr, + op, + shape, + a_strides, + b_strides, + out_strides, + 0); + return; + } + + ContiguousIterator a_it(shape, a_strides, dim - 3); + ContiguousIterator b_it(shape, b_strides, dim - 3); + auto stride = out_strides[dim - 4]; + for (int64_t elem = 0; elem < a.size(); elem += stride) { + binary_op_dims( + a_ptr + a_it.loc, + b_ptr + b_it.loc, + out_ptr + elem, + op, + shape, + a_strides, + b_strides, + out_strides, + dim - 3); + a_it.step(); + b_it.step(); + } +} + +template +void binary_op(const array& a, const array& b, array& out, Op op) { + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + + // The full computation is scalar scalar so call the base op once + if (bopt == BinaryOpType::ScalarScalar) { + *(out.data()) = op(*a.data(), *b.data()); + return; + } + + // The full computation is scalar vector so delegate to the op + if (bopt == BinaryOpType::ScalarVector) { + ScalarVector{op}(a.data(), b.data(), out.data(), b.data_size()); + return; + } + + // The full computation is vector scalar so delegate to the op + if (bopt == BinaryOpType::VectorScalar) { + VectorScalar{op}(a.data(), b.data(), out.data(), a.data_size()); + return; + } + + // The full computation is vector vector so delegate to the op + if (bopt == BinaryOpType::VectorVector) { + VectorVector{op}(a.data(), b.data(), out.data(), out.size()); + return; + } + + // General computation so let's try to optimize + auto [new_shape, new_strides] = collapse_contiguous_dims( + a.shape(), {a.strides(), b.strides(), out.strides()}); + const auto& a_strides = new_strides[0]; + const auto& b_strides = new_strides[1]; + const auto& strides = new_strides[2]; + + // Get the left-most dim such that the array is row contiguous after + auto leftmost_rc_dim = [&strides](const auto& arr_strides) { + int d = arr_strides.size() - 1; + for (; d >= 0 && arr_strides[d] == strides[d]; d--) { + } + return d + 1; + }; + auto a_rc_dim = leftmost_rc_dim(a_strides); + auto b_rc_dim = leftmost_rc_dim(b_strides); + + // Get the left-most dim such that the array is a broadcasted "scalar" after + auto leftmost_s_dim = [](const auto& arr_strides) { + int d = arr_strides.size() - 1; + for (; d >= 0 && arr_strides[d] == 0; d--) { + } + return d + 1; + }; + auto a_s_dim = leftmost_s_dim(a_strides); + auto b_s_dim = leftmost_s_dim(b_strides); + + auto ndim = new_shape.size(); + + // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous + int dim = ndim; + if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) { + bopt = BinaryOpType::VectorVector; + dim = d; + // Case 2: LxM and Fx1 where L and F are broadcastable and M is row + // contiguous + } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) { + bopt = BinaryOpType::VectorScalar; + dim = d; + // Case 3: Lx1 and FxM where L and F are broadcastable and M is row + // contiguous + } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) { + bopt = BinaryOpType::ScalarVector; + dim = d; + } + + // Can be sure dim > 0 since otherwise we would have used one of the fully + // contiguous methods above. 
Except for the case that the flags do not + // correspond to the underlying contiguity. + if (dim == 0 || strides[dim - 1] < 16) { + bopt = BinaryOpType::General; + dim = ndim; + } + + switch (bopt) { + case BinaryOpType::VectorVector: + binary_op_dispatch_dims( + a, + b, + out, + VectorVector{op}, + dim, + new_shape, + a_strides, + b_strides, + strides); + break; + case BinaryOpType::VectorScalar: + binary_op_dispatch_dims( + a, + b, + out, + VectorScalar{op}, + dim, + new_shape, + a_strides, + b_strides, + strides); + break; + case BinaryOpType::ScalarVector: + binary_op_dispatch_dims( + a, + b, + out, + ScalarVector{op}, + dim, + new_shape, + a_strides, + b_strides, + strides); + break; + default: + binary_op_dispatch_dims( + a, b, out, op, dim, new_shape, a_strides, b_strides, strides); + break; + } +} + +template +void binary_op(const array& a, const array& b, array& out, Op op) { + binary_op(a, b, out, op); +} + +template +void binary(const array& a, const array& b, array& out, Op op) { + switch (out.dtype()) { + case bool_: + binary_op(a, b, out, op); + break; + case uint8: + binary_op(a, b, out, op); + break; + case uint16: + binary_op(a, b, out, op); + break; + case uint32: + binary_op(a, b, out, op); + break; + case uint64: + binary_op(a, b, out, op); + break; + case int8: + binary_op(a, b, out, op); + break; + case int16: + binary_op(a, b, out, op); + break; + case int32: + binary_op(a, b, out, op); + break; + case int64: + binary_op(a, b, out, op); + break; + case float16: + binary_op(a, b, out, op); + break; + case float32: + binary_op(a, b, out, op); + break; + case bfloat16: + binary_op(a, b, out, op); + break; + case complex64: + binary_op(a, b, out, op); + break; + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/binary_ops.h b/mlx/backend/cpu/binary_ops.h similarity index 98% rename from mlx/backend/common/binary_ops.h rename to mlx/backend/cpu/binary_ops.h index fd10264f9..d50751ce3 100644 --- a/mlx/backend/common/binary_ops.h +++ b/mlx/backend/cpu/binary_ops.h @@ -2,7 +2,7 @@ #pragma once -#include "mlx/backend/common/simd/simd.h" +#include "mlx/backend/cpu/simd/simd.h" namespace mlx::core::detail { diff --git a/mlx/backend/common/binary_two.h b/mlx/backend/cpu/binary_two.h similarity index 99% rename from mlx/backend/common/binary_two.h rename to mlx/backend/cpu/binary_two.h index 5088c06aa..6c106b904 100644 --- a/mlx/backend/common/binary_two.h +++ b/mlx/backend/cpu/binary_two.h @@ -2,8 +2,8 @@ #pragma once -#include "mlx/backend/common/binary.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/binary.h" namespace mlx::core { diff --git a/mlx/backend/common/cholesky.cpp b/mlx/backend/cpu/cholesky.cpp similarity index 96% rename from mlx/backend/common/cholesky.cpp rename to mlx/backend/cpu/cholesky.cpp index ca09d9663..33668159a 100644 --- a/mlx/backend/common/cholesky.cpp +++ b/mlx/backend/cpu/cholesky.cpp @@ -1,8 +1,8 @@ // Copyright © 2023-2024 Apple Inc. 
#include "mlx/allocator.h" -#include "mlx/backend/common/copy.h" -#include "mlx/backend/common/lapack.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/lapack.h" #include "mlx/linalg.h" #include "mlx/primitives.h" diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/cpu/compiled.cpp similarity index 99% rename from mlx/backend/common/compiled_cpu.cpp rename to mlx/backend/cpu/compiled.cpp index e5c0156c8..905db82c8 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/cpu/compiled.cpp @@ -10,8 +10,8 @@ #include #include "mlx/backend/common/compiled.h" -#include "mlx/backend/common/compiled_preamble.h" -#include "mlx/backend/common/jit_compiler.h" +#include "mlx/backend/cpu/compiled_preamble.h" +#include "mlx/backend/cpu/jit_compiler.h" #include "mlx/device.h" #include "mlx/graph_utils.h" diff --git a/mlx/backend/common/compiled_preamble.h b/mlx/backend/cpu/compiled_preamble.h similarity index 69% rename from mlx/backend/common/compiled_preamble.h rename to mlx/backend/cpu/compiled_preamble.h index feea71bcb..31ca1b468 100644 --- a/mlx/backend/common/compiled_preamble.h +++ b/mlx/backend/cpu/compiled_preamble.h @@ -5,8 +5,8 @@ // clang-format off #include "mlx/types/half_types.h" #include "mlx/types/complex.h" -#include "mlx/backend/common/unary_ops.h" -#include "mlx/backend/common/binary_ops.h" +#include "mlx/backend/cpu/unary_ops.h" +#include "mlx/backend/cpu/binary_ops.h" // clang-format on const char* get_kernel_preamble(); diff --git a/mlx/backend/common/conv.cpp b/mlx/backend/cpu/conv.cpp similarity index 99% rename from mlx/backend/common/conv.cpp rename to mlx/backend/cpu/conv.cpp index b36f73b83..418a8cf25 100644 --- a/mlx/backend/common/conv.cpp +++ b/mlx/backend/cpu/conv.cpp @@ -3,8 +3,8 @@ #include #include -#include "mlx/backend/common/copy.h" -#include "mlx/backend/common/lapack.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" #include "mlx/utils.h" diff --git a/mlx/backend/common/copy.cpp b/mlx/backend/cpu/copy.cpp similarity index 99% rename from mlx/backend/common/copy.cpp rename to mlx/backend/cpu/copy.cpp index 41bffdcea..66c27e745 100644 --- a/mlx/backend/common/copy.cpp +++ b/mlx/backend/cpu/copy.cpp @@ -3,9 +3,9 @@ #include #include "mlx/allocator.h" -#include "mlx/backend/common/copy.h" -#include "mlx/backend/common/simd/simd.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/simd/simd.h" namespace mlx::core { diff --git a/mlx/backend/cpu/copy.h b/mlx/backend/cpu/copy.h new file mode 100644 index 000000000..1e8dc2530 --- /dev/null +++ b/mlx/backend/cpu/copy.h @@ -0,0 +1,24 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +void copy(const array& src, array& dst, CopyType ctype); +void copy_inplace(const array& src, array& dst, CopyType ctype); + +void copy_inplace( + const array& src, + array& dst, + const Shape& data_shape, + const Strides& i_strides, + const Strides& o_strides, + int64_t i_offset, + int64_t o_offset, + CopyType ctype); + +} // namespace mlx::core diff --git a/mlx/backend/common/eigh.cpp b/mlx/backend/cpu/eigh.cpp similarity index 97% rename from mlx/backend/common/eigh.cpp rename to mlx/backend/cpu/eigh.cpp index 7fa7b7fa8..be5e379f0 100644 --- a/mlx/backend/common/eigh.cpp +++ b/mlx/backend/cpu/eigh.cpp @@ -2,8 +2,8 @@ #include "mlx/allocator.h" #include "mlx/array.h" -#include "mlx/backend/common/copy.h" -#include "mlx/backend/common/lapack.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/lapack.h" #include "mlx/linalg.h" #include "mlx/primitives.h" diff --git a/mlx/backend/common/fft.cpp b/mlx/backend/cpu/fft.cpp similarity index 100% rename from mlx/backend/common/fft.cpp rename to mlx/backend/cpu/fft.cpp diff --git a/mlx/backend/common/gemm.h b/mlx/backend/cpu/gemm.h similarity index 100% rename from mlx/backend/common/gemm.h rename to mlx/backend/cpu/gemm.h diff --git a/mlx/backend/common/gemms/bnns.cpp b/mlx/backend/cpu/gemms/bnns.cpp similarity index 99% rename from mlx/backend/common/gemms/bnns.cpp rename to mlx/backend/cpu/gemms/bnns.cpp index 5c5cee739..cd517f825 100644 --- a/mlx/backend/common/gemms/bnns.cpp +++ b/mlx/backend/cpu/gemms/bnns.cpp @@ -3,8 +3,8 @@ #include #include "mlx/array.h" -#include "mlx/backend/common/gemm.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/gemm.h" #include "mlx/dtype.h" namespace mlx::core { diff --git a/mlx/backend/common/gemms/cblas.cpp b/mlx/backend/cpu/gemms/cblas.cpp similarity index 92% rename from mlx/backend/common/gemms/cblas.cpp rename to mlx/backend/cpu/gemms/cblas.cpp index e6d07bf84..fef63b3e9 100644 --- a/mlx/backend/common/gemms/cblas.cpp +++ b/mlx/backend/cpu/gemms/cblas.cpp @@ -1,8 +1,8 @@ // Copyright © 2025 Apple Inc. -#include "mlx/backend/common/gemm.h" -#include "mlx/backend/common/lapack.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/gemm.h" +#include "mlx/backend/cpu/lapack.h" namespace mlx::core { diff --git a/mlx/backend/common/gemms/no_bf16.cpp b/mlx/backend/cpu/gemms/no_bf16.cpp similarity index 89% rename from mlx/backend/common/gemms/no_bf16.cpp rename to mlx/backend/cpu/gemms/no_bf16.cpp index 2abcf1536..bf470779d 100644 --- a/mlx/backend/common/gemms/no_bf16.cpp +++ b/mlx/backend/cpu/gemms/no_bf16.cpp @@ -1,6 +1,6 @@ // Copyright © 2025 Apple Inc. -#include "mlx/backend/common/gemm.h" +#include "mlx/backend/cpu/gemm.h" namespace mlx::core { diff --git a/mlx/backend/common/gemms/no_fp16.cpp b/mlx/backend/cpu/gemms/no_fp16.cpp similarity index 89% rename from mlx/backend/common/gemms/no_fp16.cpp rename to mlx/backend/cpu/gemms/no_fp16.cpp index ccc2f2a31..7b39d4b30 100644 --- a/mlx/backend/common/gemms/no_fp16.cpp +++ b/mlx/backend/cpu/gemms/no_fp16.cpp @@ -1,6 +1,6 @@ // Copyright © 2025 Apple Inc. 
-#include "mlx/backend/common/gemm.h" +#include "mlx/backend/cpu/gemm.h" namespace mlx::core { diff --git a/mlx/backend/common/hadamard.cpp b/mlx/backend/cpu/hadamard.cpp similarity index 98% rename from mlx/backend/common/hadamard.cpp rename to mlx/backend/cpu/hadamard.cpp index 4ee05345b..eaeac83db 100644 --- a/mlx/backend/common/hadamard.cpp +++ b/mlx/backend/cpu/hadamard.cpp @@ -2,8 +2,8 @@ #include -#include "mlx/backend/common/copy.h" #include "mlx/backend/common/hadamard.h" +#include "mlx/backend/cpu/copy.h" #include "mlx/primitives.h" namespace mlx::core { diff --git a/mlx/backend/common/indexing.cpp b/mlx/backend/cpu/indexing.cpp similarity index 99% rename from mlx/backend/common/indexing.cpp rename to mlx/backend/cpu/indexing.cpp index 6798f0245..4eb48b921 100644 --- a/mlx/backend/common/indexing.cpp +++ b/mlx/backend/cpu/indexing.cpp @@ -6,8 +6,8 @@ #include "mlx/allocator.h" #include "mlx/primitives.h" -#include "mlx/backend/common/copy.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/copy.h" namespace mlx::core { diff --git a/mlx/backend/common/inverse.cpp b/mlx/backend/cpu/inverse.cpp similarity index 97% rename from mlx/backend/common/inverse.cpp rename to mlx/backend/cpu/inverse.cpp index 23e294201..40cd16efc 100644 --- a/mlx/backend/common/inverse.cpp +++ b/mlx/backend/cpu/inverse.cpp @@ -1,8 +1,8 @@ // Copyright © 2023-2024 Apple Inc. #include "mlx/allocator.h" -#include "mlx/backend/common/copy.h" -#include "mlx/backend/common/lapack.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" int strtri_wrapper(char uplo, char diag, float* matrix, int N) { diff --git a/mlx/backend/common/jit_compiler.cpp b/mlx/backend/cpu/jit_compiler.cpp similarity index 98% rename from mlx/backend/common/jit_compiler.cpp rename to mlx/backend/cpu/jit_compiler.cpp index d665c0012..0a7ff3eb0 100644 --- a/mlx/backend/common/jit_compiler.cpp +++ b/mlx/backend/cpu/jit_compiler.cpp @@ -1,6 +1,6 @@ // Copyright © 2024 Apple Inc. -#include "mlx/backend/common/jit_compiler.h" +#include "mlx/backend/cpu/jit_compiler.h" #include #include diff --git a/mlx/backend/common/jit_compiler.h b/mlx/backend/cpu/jit_compiler.h similarity index 100% rename from mlx/backend/common/jit_compiler.h rename to mlx/backend/cpu/jit_compiler.h diff --git a/mlx/backend/common/lapack.h b/mlx/backend/cpu/lapack.h similarity index 100% rename from mlx/backend/common/lapack.h rename to mlx/backend/cpu/lapack.h diff --git a/mlx/backend/common/make_compiled_preamble.ps1 b/mlx/backend/cpu/make_compiled_preamble.ps1 similarity index 97% rename from mlx/backend/common/make_compiled_preamble.ps1 rename to mlx/backend/cpu/make_compiled_preamble.ps1 index 18d057453..0cd2d1f17 100644 --- a/mlx/backend/common/make_compiled_preamble.ps1 +++ b/mlx/backend/cpu/make_compiled_preamble.ps1 @@ -8,7 +8,7 @@ $CL = $args[1] $SRCDIR = $args[2] # Get command result as array. -$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/compiled_preamble.h" +$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/cpu/compiled_preamble.h" # Remove empty lines. # Otherwise there will be too much empty lines making the result unreadable. 
$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' } diff --git a/mlx/backend/common/make_compiled_preamble.sh b/mlx/backend/cpu/make_compiled_preamble.sh similarity index 84% rename from mlx/backend/common/make_compiled_preamble.sh rename to mlx/backend/cpu/make_compiled_preamble.sh index 5f1019e21..04c7ff0c4 100644 --- a/mlx/backend/common/make_compiled_preamble.sh +++ b/mlx/backend/cpu/make_compiled_preamble.sh @@ -24,7 +24,7 @@ else CC_FLAGS="-std=c++17" fi -CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null) +CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/cpu/compiled_preamble.h" 2>/dev/null) cat << EOF > "$OUTPUT_FILE" const char* get_kernel_preamble() { diff --git a/mlx/backend/common/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp similarity index 99% rename from mlx/backend/common/masked_mm.cpp rename to mlx/backend/cpu/masked_mm.cpp index 5675399a3..5c9753e88 100644 --- a/mlx/backend/common/masked_mm.cpp +++ b/mlx/backend/cpu/masked_mm.cpp @@ -3,9 +3,9 @@ #include #include "mlx/array.h" -#include "mlx/backend/common/copy.h" -#include "mlx/backend/common/lapack.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" namespace mlx::core { diff --git a/mlx/backend/common/matmul.cpp b/mlx/backend/cpu/matmul.cpp similarity index 96% rename from mlx/backend/common/matmul.cpp rename to mlx/backend/cpu/matmul.cpp index 1966c57b6..05989c328 100644 --- a/mlx/backend/common/matmul.cpp +++ b/mlx/backend/cpu/matmul.cpp @@ -2,8 +2,8 @@ #include #include "mlx/array.h" -#include "mlx/backend/common/copy.h" -#include "mlx/backend/common/gemm.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/gemm.h" #include "mlx/primitives.h" namespace mlx::core { diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/cpu/primitives.cpp similarity index 99% rename from mlx/backend/common/primitives.cpp rename to mlx/backend/cpu/primitives.cpp index 8cd0763d8..8ae2cc520 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -7,12 +7,12 @@ #include #include "mlx/allocator.h" -#include "mlx/backend/common/arange.h" -#include "mlx/backend/common/copy.h" #include "mlx/backend/common/load.h" #include "mlx/backend/common/slicing.h" -#include "mlx/backend/common/threefry.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/arange.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/threefry.h" #include "mlx/primitives.h" #include "mlx/utils.h" diff --git a/mlx/backend/common/qrf.cpp b/mlx/backend/cpu/qrf.cpp similarity index 98% rename from mlx/backend/common/qrf.cpp rename to mlx/backend/cpu/qrf.cpp index 1c28eec26..d7caa8b68 100644 --- a/mlx/backend/common/qrf.cpp +++ b/mlx/backend/cpu/qrf.cpp @@ -1,8 +1,8 @@ // Copyright © 2023-2024 Apple Inc. 
#include "mlx/allocator.h" -#include "mlx/backend/common/copy.h" -#include "mlx/backend/common/lapack.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" namespace mlx::core { diff --git a/mlx/backend/common/quantized.cpp b/mlx/backend/cpu/quantized.cpp similarity index 99% rename from mlx/backend/common/quantized.cpp rename to mlx/backend/cpu/quantized.cpp index e0883f490..38c8004e3 100644 --- a/mlx/backend/common/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -2,8 +2,8 @@ #include -#include "mlx/backend/common/copy.h" -#include "mlx/backend/common/simd/simd.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/simd/simd.h" #include "mlx/fast_primitives.h" #include "mlx/primitives.h" #include "mlx/utils.h" diff --git a/mlx/backend/cpu/reduce.cpp b/mlx/backend/cpu/reduce.cpp new file mode 100644 index 000000000..11f27ea06 --- /dev/null +++ b/mlx/backend/cpu/reduce.cpp @@ -0,0 +1,552 @@ +// Copyright © 2023 Apple Inc. + +#include +#include +#include + +#include "mlx/backend/common/reduce.h" +#include "mlx/backend/cpu/simd/simd.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +template +struct Limits { + static const U max; + static const U min; +}; + +#define instantiate_default_limit(type) \ + template <> \ + struct Limits { \ + static constexpr type max = std::numeric_limits::max(); \ + static constexpr type min = std::numeric_limits::min(); \ + }; + +instantiate_default_limit(uint8_t); +instantiate_default_limit(uint16_t); +instantiate_default_limit(uint32_t); +instantiate_default_limit(uint64_t); +instantiate_default_limit(int8_t); +instantiate_default_limit(int16_t); +instantiate_default_limit(int32_t); +instantiate_default_limit(int64_t); + +#define instantiate_float_limit(type) \ + template <> \ + struct Limits { \ + static const type max; \ + static const type min; \ + }; + +instantiate_float_limit(float16_t); +instantiate_float_limit(bfloat16_t); +instantiate_float_limit(float); +instantiate_float_limit(complex64_t); + +template <> +struct Limits { + static constexpr bool max = true; + static constexpr bool min = false; +}; + +const float Limits::max = std::numeric_limits::infinity(); +const float Limits::min = -std::numeric_limits::infinity(); +const bfloat16_t Limits::max = + std::numeric_limits::infinity(); +const bfloat16_t Limits::min = + -std::numeric_limits::infinity(); +const float16_t Limits::max = std::numeric_limits::infinity(); +const float16_t Limits::min = + -std::numeric_limits::infinity(); +const complex64_t Limits::max = + std::numeric_limits::infinity(); +const complex64_t Limits::min = + -std::numeric_limits::infinity(); + +template +void strided_reduce( + const T* x, + U* accumulator, + int size, + size_t stride, + Op op) { + constexpr int N = std::min(simd::max_size, simd::max_size); + for (int i = 0; i < size; i++) { + U* moving_accumulator = accumulator; + auto s = stride; + while (s >= N) { + auto acc = simd::load(moving_accumulator); + auto v = simd::Simd(simd::load(x)); + simd::store(moving_accumulator, op(acc, v)); + moving_accumulator += N; + x += N; + s -= N; + } + while (s-- > 0) { + *moving_accumulator = op(*moving_accumulator, *x); + moving_accumulator++; + x++; + } + } +}; + +template +void contiguous_reduce(const T* x, U* accumulator, int size, Op op, U init) { + constexpr int N = std::min(simd::max_size, simd::max_size); + simd::Simd accumulator_v(init); + while (size >= N) { + accumulator_v = op(accumulator_v, simd::Simd(simd::load(x))); + x += N; + size -= N; + } + 
*accumulator = op(*accumulator, op(accumulator_v)); + while (size-- > 0) { + *accumulator = op(*accumulator, *x); + x++; + } +} + +// Helper for the ndimensional strided loop +void nd_loop( + std::function callback, + const Shape& shape, + const Strides& strides) { + std::function loop_inner; + loop_inner = [&](int dim, int offset) { + if (dim < shape.size() - 1) { + auto size = shape[dim]; + auto stride = strides[dim]; + for (int i = 0; i < size; i++) { + loop_inner(dim + 1, offset + i * stride); + } + } else { + auto size = shape[dim]; + auto stride = strides[dim]; + for (int i = 0; i < size; i++) { + callback(offset + i * stride); + } + } + }; + loop_inner(0, 0); +} + +template +void reduction_op( + const array& x, + array& out, + const std::vector& axes, + U init, + Op op) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + ReductionPlan plan = get_reduction_plan(x, axes); + + if (plan.type == ContiguousAllReduce) { + U* out_ptr = out.data(); + *out_ptr = init; + contiguous_reduce(x.data(), out_ptr, x.size(), op, init); + return; + } + + if (plan.type == ContiguousReduce && plan.shape.size() == 1) { + int reduction_size = plan.shape[0]; + const T* x_ptr = x.data(); + U* out_ptr = out.data(); + for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) { + *out_ptr = init; + contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init); + } + return; + } + + if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) { + int reduction_size = plan.shape.back(); + plan.shape.pop_back(); + plan.strides.pop_back(); + const T* x_ptr = x.data(); + U* out_ptr = out.data(); + // Unrolling the following loop (and implementing it in order for + // ContiguousReduce) should hold extra performance boost. + auto [shape, strides] = shapes_without_reduction_axes(x, axes); + if (plan.shape.size() == 0) { + for (int i = 0; i < out.size(); i++, out_ptr++) { + int offset = elem_to_loc(i, shape, strides); + *out_ptr = init; + contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init); + } + } else { + for (int i = 0; i < out.size(); i++, out_ptr++) { + int offset = elem_to_loc(i, shape, strides); + *out_ptr = init; + nd_loop( + [&](int extra_offset) { + contiguous_reduce( + x_ptr + offset + extra_offset, + out_ptr, + reduction_size, + op, + init); + }, + plan.shape, + plan.strides); + } + } + return; + } + + if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) { + int reduction_size = plan.shape.back(); + size_t reduction_stride = plan.strides.back(); + plan.shape.pop_back(); + plan.strides.pop_back(); + const T* x_ptr = x.data(); + U* out_ptr = out.data(); + for (int i = 0; i < out.size(); i += reduction_stride) { + std::fill_n(out_ptr, reduction_stride, init); + strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op); + x_ptr += reduction_stride * reduction_size; + out_ptr += reduction_stride; + } + return; + } + + if (plan.type == GeneralStridedReduce || + plan.type == ContiguousStridedReduce) { + int reduction_size = plan.shape.back(); + size_t reduction_stride = plan.strides.back(); + plan.shape.pop_back(); + plan.strides.pop_back(); + const T* x_ptr = x.data(); + U* out_ptr = out.data(); + auto [shape, strides] = shapes_without_reduction_axes(x, axes); + if (plan.shape.size() == 0) { + for (int i = 0; i < out.size(); i += reduction_stride) { + int offset = elem_to_loc(i, shape, strides); + std::fill_n(out_ptr, reduction_stride, init); + strided_reduce( + x_ptr + offset, out_ptr, reduction_size, reduction_stride, op); + out_ptr 
+= reduction_stride; + } + } else { + for (int i = 0; i < out.size(); i += reduction_stride) { + int offset = elem_to_loc(i, shape, strides); + std::fill_n(out_ptr, reduction_stride, init); + nd_loop( + [&](int extra_offset) { + strided_reduce( + x_ptr + offset + extra_offset, + out_ptr, + reduction_size, + reduction_stride, + op); + }, + plan.shape, + plan.strides); + out_ptr += reduction_stride; + } + } + return; + } + + if (plan.type == GeneralReduce) { + const T* x_ptr = x.data(); + U* out_ptr = out.data(); + auto [shape, strides] = shapes_without_reduction_axes(x, axes); + for (int i = 0; i < out.size(); i++, out_ptr++) { + int offset = elem_to_loc(i, shape, strides); + U val = init; + nd_loop( + [&](int extra_offset) { + val = op(val, *(x_ptr + offset + extra_offset)); + }, + plan.shape, + plan.strides); + *out_ptr = val; + } + } +} + +struct AndReduce { + template + bool operator()(bool x, T y) { + return x & (y != 0); + } + + bool operator()(bool x, bool y) { + return x & y; + } + + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return x & (y != 0); + }; + + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return x & y; + }; + + template + bool operator()(simd::Simd x) { + return simd::all(x); + }; +}; + +struct OrReduce { + template + bool operator()(bool x, T y) { + return x | (y != 0); + } + + bool operator()(bool x, bool y) { + return x | y; + } + + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return x | (y != 0); + }; + + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return x | y; + }; + + template + bool operator()(simd::Simd x) { + return simd::any(x); + }; +}; + +struct MaxReduce { + template + T operator()(T y, T x) { + return (*this)(simd::Simd(x), simd::Simd(y)).value; + }; + + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return simd::maximum(x, y); + }; + + template + T operator()(simd::Simd x) { + return simd::max(x); + }; +}; + +struct MinReduce { + template + T operator()(T y, T x) { + return (*this)(simd::Simd(x), simd::Simd(y)).value; + }; + + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return simd::minimum(x, y); + }; + + template + T operator()(simd::Simd x) { + return simd::min(x); + }; +}; + +struct SumReduce { + template + U operator()(U y, T x) { + return x + y; + }; + + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return y + x; + }; + + template + T operator()(simd::Simd x) { + return simd::sum(x); + }; +}; + +struct ProdReduce { + template + U operator()(U y, T x) { + return x * y; + }; + + template + simd::Simd operator()(simd::Simd y, simd::Simd x) { + return x * y; + }; + + template + T operator()(simd::Simd x) { + return simd::prod(x); + }; +}; + +template +void reduce_dispatch_and_or( + const array& in, + array& out, + Reduce::ReduceType rtype, + const std::vector& axes) { + if (rtype == Reduce::And) { + reduction_op(in, out, axes, true, AndReduce()); + } else { + reduction_op(in, out, axes, false, OrReduce()); + } +} + +template +void reduce_dispatch_sum_prod( + const array& in, + array& out, + Reduce::ReduceType rtype, + const std::vector& axes) { + if (rtype == Reduce::Sum) { + if constexpr (std::is_integral_v && sizeof(InT) <= 4) { + reduction_op(in, out, axes, 0, SumReduce()); + } else { + reduction_op(in, out, axes, 0, SumReduce()); + } + } else { + if constexpr (std::is_integral_v && sizeof(InT) <= 4) { + reduction_op(in, out, axes, 1, ProdReduce()); + } else { + reduction_op(in, out, axes, 1, ProdReduce()); + } + 
} +} + +template +void reduce_dispatch_min_max( + const array& in, + array& out, + Reduce::ReduceType rtype, + const std::vector& axes) { + if (rtype == Reduce::Max) { + auto init = Limits::min; + reduction_op(in, out, axes, init, MaxReduce()); + } else { + auto init = Limits::max; + reduction_op(in, out, axes, init, MinReduce()); + } +} + +void Reduce::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + switch (reduce_type_) { + case Reduce::And: + case Reduce::Or: { + switch (in.dtype()) { + case bool_: + case uint8: + case int8: + reduce_dispatch_and_or(in, out, reduce_type_, axes_); + break; + case int16: + case uint16: + case float16: + case bfloat16: + reduce_dispatch_and_or(in, out, reduce_type_, axes_); + break; + case uint32: + case int32: + case float32: + reduce_dispatch_and_or(in, out, reduce_type_, axes_); + break; + case uint64: + case int64: + case complex64: + reduce_dispatch_and_or(in, out, reduce_type_, axes_); + break; + } + break; + } + case Reduce::Sum: + case Reduce::Prod: { + switch (in.dtype()) { + case bool_: + case uint8: + case int8: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case int16: + case uint16: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case int32: + case uint32: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case int64: + case uint64: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case float16: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case bfloat16: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case float32: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case complex64: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + } + break; + } + case Reduce::Max: + case Reduce::Min: { + switch (in.dtype()) { + case bool_: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case uint8: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case uint16: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case uint32: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case uint64: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case int8: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case int16: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case int32: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case int64: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case float16: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case float32: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case bfloat16: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + case complex64: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; + } + break; + } + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/scan.cpp b/mlx/backend/cpu/scan.cpp similarity index 99% rename from mlx/backend/common/scan.cpp rename to mlx/backend/cpu/scan.cpp index 2430f3172..0c231baab 100644 --- a/mlx/backend/common/scan.cpp +++ b/mlx/backend/cpu/scan.cpp @@ -2,9 +2,9 @@ #include -#include "mlx/backend/common/copy.h" -#include "mlx/backend/common/simd/simd.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/simd/simd.h" #include "mlx/primitives.h" namespace mlx::core { diff --git 
a/mlx/backend/common/select.cpp b/mlx/backend/cpu/select.cpp similarity index 95% rename from mlx/backend/common/select.cpp rename to mlx/backend/cpu/select.cpp index 04c28ef04..a08805893 100644 --- a/mlx/backend/common/select.cpp +++ b/mlx/backend/cpu/select.cpp @@ -2,8 +2,8 @@ #include -#include "mlx/backend/common/binary_ops.h" -#include "mlx/backend/common/ternary.h" +#include "mlx/backend/cpu/binary_ops.h" +#include "mlx/backend/cpu/ternary.h" #include "mlx/primitives.h" namespace mlx::core { diff --git a/mlx/backend/common/simd/accelerate_fp16_simd.h b/mlx/backend/cpu/simd/accelerate_fp16_simd.h similarity index 94% rename from mlx/backend/common/simd/accelerate_fp16_simd.h rename to mlx/backend/cpu/simd/accelerate_fp16_simd.h index 7fa5c9467..1f21d2e18 100644 --- a/mlx/backend/common/simd/accelerate_fp16_simd.h +++ b/mlx/backend/cpu/simd/accelerate_fp16_simd.h @@ -1,9 +1,9 @@ #pragma once -#include "mlx/backend/common/simd/base_simd.h" +#include "mlx/backend/cpu/simd/base_simd.h" #if MLX_SIMD_LIBRARY_VERSION < 6 -#include "mlx/backend/common/simd/neon_fp16_simd.h" +#include "mlx/backend/cpu/simd/neon_fp16_simd.h" #endif namespace mlx::core::simd { diff --git a/mlx/backend/common/simd/accelerate_simd.h b/mlx/backend/cpu/simd/accelerate_simd.h similarity index 98% rename from mlx/backend/common/simd/accelerate_simd.h rename to mlx/backend/cpu/simd/accelerate_simd.h index 7edb06df5..59821d03b 100644 --- a/mlx/backend/common/simd/accelerate_simd.h +++ b/mlx/backend/cpu/simd/accelerate_simd.h @@ -7,7 +7,7 @@ #include #include -#include "mlx/backend/common/simd/base_simd.h" +#include "mlx/backend/cpu/simd/base_simd.h" // There seems to be a bug in sims/base.h // __XROS_2_0 is not defined, the expression evaluates @@ -299,5 +299,5 @@ T prod(Simd x) { } // namespace mlx::core::simd #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -#include "mlx/backend/common/simd/accelerate_fp16_simd.h" +#include "mlx/backend/cpu/simd/accelerate_fp16_simd.h" #endif diff --git a/mlx/backend/common/simd/base_simd.h b/mlx/backend/cpu/simd/base_simd.h similarity index 100% rename from mlx/backend/common/simd/base_simd.h rename to mlx/backend/cpu/simd/base_simd.h diff --git a/mlx/backend/common/simd/math.h b/mlx/backend/cpu/simd/math.h similarity index 99% rename from mlx/backend/common/simd/math.h rename to mlx/backend/cpu/simd/math.h index c7061b2b1..3730aac5e 100644 --- a/mlx/backend/common/simd/math.h +++ b/mlx/backend/cpu/simd/math.h @@ -2,7 +2,7 @@ #pragma once -#include "mlx/backend/common/simd/type.h" +#include "mlx/backend/cpu/simd/type.h" namespace mlx::core::simd { diff --git a/mlx/backend/common/simd/neon_fp16_simd.h b/mlx/backend/cpu/simd/neon_fp16_simd.h similarity index 99% rename from mlx/backend/common/simd/neon_fp16_simd.h rename to mlx/backend/cpu/simd/neon_fp16_simd.h index 269ff1305..5d32042cc 100644 --- a/mlx/backend/common/simd/neon_fp16_simd.h +++ b/mlx/backend/cpu/simd/neon_fp16_simd.h @@ -2,7 +2,7 @@ #include -#include "mlx/backend/common/simd/base_simd.h" +#include "mlx/backend/cpu/simd/base_simd.h" namespace mlx::core::simd { diff --git a/mlx/backend/cpu/simd/simd.h b/mlx/backend/cpu/simd/simd.h new file mode 100644 index 000000000..8700f24c0 --- /dev/null +++ b/mlx/backend/cpu/simd/simd.h @@ -0,0 +1,4 @@ +#pragma once + +#include "mlx/backend/cpu/simd/math.h" +#include "mlx/backend/cpu/simd/type.h" diff --git a/mlx/backend/cpu/simd/type.h b/mlx/backend/cpu/simd/type.h new file mode 100644 index 000000000..c24da22b2 --- /dev/null +++ b/mlx/backend/cpu/simd/type.h @@ -0,0 +1,7 @@ +#pragma 
once + +#include "mlx/backend/cpu/simd/base_simd.h" + +#ifdef MLX_USE_ACCELERATE +#include "mlx/backend/cpu/simd/accelerate_simd.h" +#endif diff --git a/mlx/backend/cpu/slicing.h b/mlx/backend/cpu/slicing.h new file mode 100644 index 000000000..eda37320d --- /dev/null +++ b/mlx/backend/cpu/slicing.h @@ -0,0 +1,21 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +std::tuple prepare_slice( + const array& in, + const Shape& start_indices, + const Shape& strides); + +void shared_buffer_slice( + const array& in, + const Strides& out_strides, + size_t data_offset, + size_t data_size, + array& out); + +} // namespace mlx::core diff --git a/mlx/backend/common/softmax.cpp b/mlx/backend/cpu/softmax.cpp similarity index 98% rename from mlx/backend/common/softmax.cpp rename to mlx/backend/cpu/softmax.cpp index 2c7579930..3c80d7f28 100644 --- a/mlx/backend/common/softmax.cpp +++ b/mlx/backend/cpu/softmax.cpp @@ -3,8 +3,8 @@ #include #include -#include "mlx/backend/common/copy.h" -#include "mlx/backend/common/simd/simd.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/simd/simd.h" #include "mlx/primitives.h" namespace mlx::core { diff --git a/mlx/backend/common/sort.cpp b/mlx/backend/cpu/sort.cpp similarity index 99% rename from mlx/backend/common/sort.cpp rename to mlx/backend/cpu/sort.cpp index 1304186d6..078b68ade 100644 --- a/mlx/backend/common/sort.cpp +++ b/mlx/backend/cpu/sort.cpp @@ -5,8 +5,8 @@ #include #include -#include "mlx/backend/common/copy.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/copy.h" #include "mlx/primitives.h" diff --git a/mlx/backend/common/svd.cpp b/mlx/backend/cpu/svd.cpp similarity index 98% rename from mlx/backend/common/svd.cpp rename to mlx/backend/cpu/svd.cpp index 71c620db1..f18ab4f91 100644 --- a/mlx/backend/common/svd.cpp +++ b/mlx/backend/cpu/svd.cpp @@ -1,8 +1,8 @@ // Copyright © 2024 Apple Inc. #include "mlx/allocator.h" -#include "mlx/backend/common/copy.h" -#include "mlx/backend/common/lapack.h" +#include "mlx/backend/cpu/copy.h" +#include "mlx/backend/cpu/lapack.h" #include "mlx/primitives.h" namespace mlx::core { diff --git a/mlx/backend/cpu/ternary.h b/mlx/backend/cpu/ternary.h new file mode 100644 index 000000000..87c27c86c --- /dev/null +++ b/mlx/backend/cpu/ternary.h @@ -0,0 +1,157 @@ +// Copyright © 2023 Apple Inc. 
+ +#pragma once +#include "mlx/allocator.h" +#include "mlx/array.h" +#include "mlx/backend/common/ternary.h" +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +template +void ternary_op_dims( + const T1* a, + const T2* b, + const T3* c, + U* out, + Op op, + const Shape& shape, + const Strides& a_strides, + const Strides& b_strides, + const Strides& c_strides, + const Strides& out_strides, + int axis) { + auto stride_a = a_strides[axis]; + auto stride_b = b_strides[axis]; + auto stride_c = c_strides[axis]; + auto stride_out = out_strides[axis]; + auto N = shape[axis]; + + for (int i = 0; i < N; i++) { + if constexpr (D > 1) { + ternary_op_dims( + a, + b, + c, + out, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + axis + 1); + } else { + *out = op(*a, *b, *c); + } + a += stride_a; + b += stride_b; + c += stride_c; + out += stride_out; + } +} + +template +void ternary_op_dispatch_dims( + const array& a, + const array& b, + const array& c, + array& out, + Op op) { + auto [shape, strides] = collapse_contiguous_dims( + a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()}); + const auto& a_strides = strides[0]; + const auto& b_strides = strides[1]; + const auto& c_strides = strides[2]; + const auto& out_strides = strides[3]; + + const T1* a_ptr = a.data(); + const T2* b_ptr = b.data(); + const T3* c_ptr = c.data(); + U* out_ptr = out.data(); + int ndim = shape.size(); + switch (ndim) { + case 1: + ternary_op_dims( + a_ptr, + b_ptr, + c_ptr, + out_ptr, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + 0); + return; + case 2: + ternary_op_dims( + a_ptr, + b_ptr, + c_ptr, + out_ptr, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + 0); + return; + } + + ContiguousIterator a_it(shape, a_strides, ndim - 2); + ContiguousIterator b_it(shape, b_strides, ndim - 2); + ContiguousIterator c_it(shape, c_strides, ndim - 2); + auto stride = out_strides[ndim - 3]; + for (size_t elem = 0; elem < a.size(); elem += stride) { + ternary_op_dims( + a_ptr + a_it.loc, + b_ptr + b_it.loc, + c_ptr + c_it.loc, + out_ptr + elem, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + ndim - 2); + a_it.step(); + b_it.step(); + c_it.step(); + } +} + +template +void ternary_op( + const array& a, + const array& b, + const array& c, + array& out, + Op op) { + TernaryOpType topt = get_ternary_op_type(a, b, c); + set_ternary_op_output_data(a, b, c, out, topt); + + // The full computation is scalar-scalar-scalar so we call the base op once. + if (topt == TernaryOpType::ScalarScalarScalar) { + *(out.data()) = op(*a.data(), *b.data(), *c.data()); + } else if (topt == TernaryOpType::VectorVectorVector) { + const T1* a_ptr = a.data(); + const T2* b_ptr = b.data(); + const T3* c_ptr = c.data(); + U* out_ptr = out.data(); + for (size_t i = 0; i < out.size(); ++i) { + *out_ptr = op(*a_ptr, *b_ptr, *c_ptr); + a_ptr++; + b_ptr++; + c_ptr++; + out_ptr++; + } + } else { + ternary_op_dispatch_dims(a, b, c, out, op); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/threefry.cpp b/mlx/backend/cpu/threefry.cpp similarity index 95% rename from mlx/backend/common/threefry.cpp rename to mlx/backend/cpu/threefry.cpp index b4905acbe..8056b842c 100644 --- a/mlx/backend/common/threefry.cpp +++ b/mlx/backend/cpu/threefry.cpp @@ -1,6 +1,6 @@ // Copyright © 2023 Apple Inc. 
-#include "mlx/backend/common/threefry.h" +#include "mlx/backend/cpu/threefry.h" namespace mlx::core::random { diff --git a/mlx/backend/common/threefry.h b/mlx/backend/cpu/threefry.h similarity index 100% rename from mlx/backend/common/threefry.h rename to mlx/backend/cpu/threefry.h diff --git a/mlx/backend/common/unary.cpp b/mlx/backend/cpu/unary.cpp similarity index 98% rename from mlx/backend/common/unary.cpp rename to mlx/backend/cpu/unary.cpp index be9fec715..c6431baec 100644 --- a/mlx/backend/common/unary.cpp +++ b/mlx/backend/cpu/unary.cpp @@ -2,8 +2,8 @@ #include -#include "mlx/backend/common/unary.h" -#include "mlx/backend/common/unary_ops.h" +#include "mlx/backend/cpu/unary.h" +#include "mlx/backend/cpu/unary_ops.h" #include "mlx/primitives.h" namespace mlx::core { diff --git a/mlx/backend/common/unary.h b/mlx/backend/cpu/unary.h similarity index 97% rename from mlx/backend/common/unary.h rename to mlx/backend/cpu/unary.h index e38937d3b..6dccaf615 100644 --- a/mlx/backend/common/unary.h +++ b/mlx/backend/cpu/unary.h @@ -4,14 +4,12 @@ #include "mlx/allocator.h" #include "mlx/array.h" -#include "mlx/backend/common/simd/simd.h" #include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/simd/simd.h" #include "mlx/utils.h" namespace mlx::core { -namespace { - void set_unary_output_data(const array& in, array& out) { if (is_donatable(in, out)) { out.copy_shared_buffer(in); @@ -137,6 +135,4 @@ void unary_fp(const array& a, array& out, Op op) { } } -} // namespace - } // namespace mlx::core diff --git a/mlx/backend/common/unary_ops.h b/mlx/backend/cpu/unary_ops.h similarity index 98% rename from mlx/backend/common/unary_ops.h rename to mlx/backend/cpu/unary_ops.h index 11a69c2ca..3019ad91e 100644 --- a/mlx/backend/common/unary_ops.h +++ b/mlx/backend/cpu/unary_ops.h @@ -6,7 +6,7 @@ #include #include -#include "mlx/backend/common/simd/simd.h" +#include "mlx/backend/cpu/simd/simd.h" namespace mlx::core::detail { diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 16d2db362..88d63ba72 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -2,6 +2,7 @@ #include +#include "mlx/backend/common/utils.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 83aa18b88..52131b4a8 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -6,6 +6,7 @@ #include #include "mlx/3rdparty/pocketfft.h" +#include "mlx/backend/common/utils.h" #include "mlx/backend/metal/binary.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/kernels.h" diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 09ca27a4e..707e43fa1 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -5,6 +5,7 @@ #include #include +#include "mlx/backend/common/utils.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" diff --git a/mlx/backend/no_cpu/CMakeLists.txt b/mlx/backend/no_cpu/CMakeLists.txt index 029c720d7..ac17d8059 100644 --- a/mlx/backend/no_cpu/CMakeLists.txt +++ b/mlx/backend/no_cpu/CMakeLists.txt @@ -1,10 +1,2 @@ -target_sources( - mlx - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common/load.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp - 
${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/../common/utils.cpp) +target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp) diff --git a/mlx/backend/common/compiled_nocpu.cpp b/mlx/backend/no_cpu/compiled.cpp similarity index 91% rename from mlx/backend/common/compiled_nocpu.cpp rename to mlx/backend/no_cpu/compiled.cpp index 3e081d1a6..c1c42c735 100644 --- a/mlx/backend/common/compiled_nocpu.cpp +++ b/mlx/backend/no_cpu/compiled.cpp @@ -1,6 +1,7 @@ // Copyright © 2023-2024 Apple Inc. -#include "mlx/backend/common/compiled.h" +#include "mlx/compile_impl.h" +#include "mlx/primitives.h" namespace mlx::core { diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index b233df1b5..41ef03d97 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -3,7 +3,7 @@ #include #include -#include "mlx/backend/common/copy.h" +#include "mlx/backend/cpu/copy.h" #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/mpi/mpi.h" diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index c30de91a2..ad3f2e0a5 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -13,7 +13,7 @@ #include -#include "mlx/backend/common/copy.h" +#include "mlx/backend/cpu/copy.h" #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" #include "mlx/threadpool.h"
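
A note for readers skimming the new mlx/backend/cpu/binary.h hunk above: the VectorScalar, ScalarVector, and VectorVector functors all share one shape, a blocked main loop over SIMD-width chunks followed by a scalar tail for the leftover elements. The standalone sketch below shows that pattern with a plain fixed-size block standing in for simd::Simd; the name apply_vector_vector and the BLOCK constant are illustrative only and are not MLX APIs.

// Minimal sketch of the vector-vector kernel pattern used in
// mlx/backend/cpu/binary.h. BLOCK stands in for the SIMD width; the
// inner per-block loop is what the real header replaces with a single
// simd::load / op / simd::store. apply_vector_vector is a hypothetical name.
#include <iostream>
#include <vector>

template <typename T, typename U, typename Op>
void apply_vector_vector(const T* a, const T* b, U* dst, int size, Op op) {
  constexpr int BLOCK = 8; // stand-in for the SIMD vector width
  // Blocked main loop over full BLOCK-sized chunks.
  while (size >= BLOCK) {
    for (int i = 0; i < BLOCK; ++i) {
      dst[i] = op(a[i], b[i]);
    }
    a += BLOCK;
    b += BLOCK;
    dst += BLOCK;
    size -= BLOCK;
  }
  // Scalar tail for the remaining (size % BLOCK) elements.
  while (size-- > 0) {
    *dst++ = op(*a++, *b++);
  }
}

int main() {
  std::vector<float> a(19, 2.0f), b(19, 3.0f), out(19);
  apply_vector_vector(a.data(), b.data(), out.data(), a.size(),
                      [](float x, float y) { return x + y; });
  std::cout << out.front() << " " << out.back() << "\n"; // prints: 5 5
}

In the real header the inner per-element loop collapses to one SIMD operation per block, which is where the benefit of the VectorVector/VectorScalar/ScalarVector split comes from.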
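
The general (non-contiguous) paths in the new mlx/backend/cpu/reduce.cpp lean on a small recursive helper, nd_loop, that walks an arbitrary-rank strided view and hands each element's linear offset to a callback. Below is a self-contained sketch of that traversal, assuming plain std::vector aliases for Shape and Strides; treat it as an illustration of the idea rather than the exact MLX signature.

// Sketch of the recursive n-dimensional strided traversal used by the
// general reduction paths. The callback receives the linear offset of
// each element given the (possibly non-contiguous) strides.
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

using Shape = std::vector<int>;
using Strides = std::vector<int64_t>;

void nd_loop(const std::function<void(int64_t)>& callback,
             const Shape& shape,
             const Strides& strides) {
  std::function<void(int, int64_t)> loop_inner;
  loop_inner = [&](int dim, int64_t offset) {
    if (dim == static_cast<int>(shape.size()) - 1) {
      // Innermost dimension: fire the callback once per element.
      for (int i = 0; i < shape[dim]; ++i) {
        callback(offset + i * strides[dim]);
      }
    } else {
      // Outer dimensions: recurse with the accumulated offset.
      for (int i = 0; i < shape[dim]; ++i) {
        loop_inner(dim + 1, offset + i * strides[dim]);
      }
    }
  };
  loop_inner(0, 0);
}

int main() {
  // Visit a 2x3 view stored with a row stride of 4 (i.e. padded rows).
  nd_loop([](int64_t loc) { std::cout << loc << " "; }, {2, 3}, {4, 1});
  std::cout << "\n"; // prints: 0 1 2 4 5 6
}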
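
Both binary.h and reduce.cpp end with the same dispatch idiom: a switch over the runtime dtype that instantiates the templated kernel for the matching C++ type. The compressed sketch below uses a made-up two-member Dtype enum and a hypothetical dispatch_binary helper; MLX's real switches cover the full set of dtypes from bool_ through complex64.

// Illustration of the dtype -> template dispatch idiom at the bottom of
// binary.h and reduce.cpp. Dtype and dispatch_binary are hypothetical
// stand-ins for MLX's dtype machinery.
#include <cstdint>
#include <iostream>

enum class Dtype { int32, float32 };

template <typename T, typename Op>
void binary_kernel(const T* a, const T* b, T* out, int n, Op op) {
  for (int i = 0; i < n; ++i) {
    out[i] = op(a[i], b[i]);
  }
}

template <typename Op>
void dispatch_binary(Dtype dt, const void* a, const void* b, void* out,
                     int n, Op op) {
  switch (dt) {
    case Dtype::int32:
      binary_kernel(static_cast<const int32_t*>(a),
                    static_cast<const int32_t*>(b),
                    static_cast<int32_t*>(out), n, op);
      break;
    case Dtype::float32:
      binary_kernel(static_cast<const float*>(a),
                    static_cast<const float*>(b),
                    static_cast<float*>(out), n, op);
      break;
  }
}

int main() {
  float a[3] = {1, 2, 3}, b[3] = {4, 5, 6}, out[3];
  dispatch_binary(Dtype::float32, a, b, out, 3,
                  [](auto x, auto y) { return x * y; });
  std::cout << out[0] << " " << out[2] << "\n"; // prints: 4 18
}

Keeping the switch in one place means the kernels themselves stay fully typed and the type erasure happens exactly once per primitive.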