mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Refactor common into cpu specific and truly common (#1817)
* refactor * fix extension example * fix no-cpu
This commit is contained in:
parent
ec7c7def40
commit
1156c84e86
@ -6,6 +6,7 @@
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include "axpby/axpby.h"
|
||||
|
@ -29,8 +29,10 @@ if(WIN32)
|
||||
set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CPU)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
||||
|
||||
if(MLX_BUILD_CPU)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cpu)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
|
||||
endif()
|
||||
|
@ -1,88 +1,8 @@
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
set(COMPILER ${CMAKE_C_COMPILER})
|
||||
set(CLANG TRUE)
|
||||
else()
|
||||
set(COMPILER ${CMAKE_CXX_COMPILER})
|
||||
endif()
|
||||
|
||||
set(COMPILE_DEPS
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
|
||||
simd/simd.h
|
||||
simd/base_simd.h
|
||||
simd/math.h
|
||||
simd/type.h
|
||||
unary_ops.h
|
||||
binary_ops.h)
|
||||
|
||||
if(MSVC)
|
||||
set(SHELL_EXT ps1)
|
||||
set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
|
||||
else()
|
||||
set(SHELL_EXT sh)
|
||||
set(SHELL_CMD bash)
|
||||
endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT compiled_preamble.cpp
|
||||
COMMAND
|
||||
${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
|
||||
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
|
||||
DEPENDS make_compiled_preamble.${SHELL_EXT} compiled_preamble.h
|
||||
${COMPILE_DEPS})
|
||||
|
||||
add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
|
||||
|
||||
add_dependencies(mlx cpu_compiled_preamble)
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cblas.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
|
||||
|
||||
if(MLX_BUILD_ACCELERATE)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
|
||||
endif()
|
||||
|
||||
if(IOS)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp)
|
||||
endif()
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||
|
@ -1,18 +1,13 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
enum class BinaryOpType {
|
||||
ScalarScalar,
|
||||
ScalarVector,
|
||||
@ -21,7 +16,7 @@ enum class BinaryOpType {
|
||||
General,
|
||||
};
|
||||
|
||||
BinaryOpType get_binary_op_type(const array& a, const array& b) {
|
||||
inline BinaryOpType get_binary_op_type(const array& a, const array& b) {
|
||||
BinaryOpType bopt;
|
||||
if (a.data_size() == 1 && b.data_size() == 1) {
|
||||
bopt = BinaryOpType::ScalarScalar;
|
||||
@ -39,7 +34,7 @@ BinaryOpType get_binary_op_type(const array& a, const array& b) {
|
||||
return bopt;
|
||||
}
|
||||
|
||||
void set_binary_op_output_data(
|
||||
inline void set_binary_op_output_data(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
@ -124,361 +119,4 @@ void set_binary_op_output_data(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
struct VectorScalar {
|
||||
Op op;
|
||||
|
||||
VectorScalar(Op op_) : op(op_) {}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
T scalar = *b;
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
|
||||
dst += N;
|
||||
a += N;
|
||||
size -= N;
|
||||
}
|
||||
while (size-- > 0) {
|
||||
*dst = op(*a, scalar);
|
||||
dst++;
|
||||
a++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Op>
|
||||
struct ScalarVector {
|
||||
Op op;
|
||||
|
||||
ScalarVector(Op op_) : op(op_) {}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
T scalar = *a;
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
|
||||
dst += N;
|
||||
b += N;
|
||||
size -= N;
|
||||
}
|
||||
while (size-- > 0) {
|
||||
*dst = op(scalar, *b);
|
||||
dst++;
|
||||
b++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Op>
|
||||
struct VectorVector {
|
||||
Op op;
|
||||
|
||||
VectorVector(Op op_) : op(op_) {}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::load<T, N>(a), simd::load<T, N>(b)));
|
||||
dst += N;
|
||||
a += N;
|
||||
b += N;
|
||||
size -= N;
|
||||
}
|
||||
while (size-- > 0) {
|
||||
*dst = op(*a, *b);
|
||||
dst++;
|
||||
a++;
|
||||
b++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op, int D, bool Strided>
|
||||
void binary_op_dims(
|
||||
const T* a,
|
||||
const T* b,
|
||||
U* out,
|
||||
Op op,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& out_strides,
|
||||
int axis) {
|
||||
auto stride_a = a_strides[axis];
|
||||
auto stride_b = b_strides[axis];
|
||||
auto stride_out = out_strides[axis];
|
||||
auto N = shape[axis];
|
||||
|
||||
for (int i = 0; i < N; i++) {
|
||||
if constexpr (D > 1) {
|
||||
binary_op_dims<T, U, Op, D - 1, Strided>(
|
||||
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
|
||||
} else {
|
||||
if constexpr (Strided) {
|
||||
op(a, b, out, stride_out);
|
||||
} else {
|
||||
*out = op(*a, *b);
|
||||
}
|
||||
}
|
||||
out += stride_out;
|
||||
a += stride_a;
|
||||
b += stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, bool Strided, typename Op>
|
||||
void binary_op_dispatch_dims(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
int dim,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& out_strides) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
switch (dim) {
|
||||
case 1:
|
||||
binary_op_dims<T, U, Op, 1, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
case 2:
|
||||
binary_op_dims<T, U, Op, 2, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
case 3:
|
||||
binary_op_dims<T, U, Op, 3, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(shape, a_strides, dim - 3);
|
||||
ContiguousIterator b_it(shape, b_strides, dim - 3);
|
||||
auto stride = out_strides[dim - 4];
|
||||
for (int64_t elem = 0; elem < a.size(); elem += stride) {
|
||||
binary_op_dims<T, U, Op, 3, Strided>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
out_ptr + elem,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
dim - 3);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
// The full computation is scalar scalar so call the base op once
|
||||
if (bopt == BinaryOpType::ScalarScalar) {
|
||||
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is scalar vector so delegate to the op
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector scalar so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorScalar) {
|
||||
VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector vector so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorVector) {
|
||||
VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
|
||||
return;
|
||||
}
|
||||
|
||||
// General computation so let's try to optimize
|
||||
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), out.strides()});
|
||||
const auto& a_strides = new_strides[0];
|
||||
const auto& b_strides = new_strides[1];
|
||||
const auto& strides = new_strides[2];
|
||||
|
||||
// Get the left-most dim such that the array is row contiguous after
|
||||
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_rc_dim = leftmost_rc_dim(a_strides);
|
||||
auto b_rc_dim = leftmost_rc_dim(b_strides);
|
||||
|
||||
// Get the left-most dim such that the array is a broadcasted "scalar" after
|
||||
auto leftmost_s_dim = [](const auto& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == 0; d--) {
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_s_dim = leftmost_s_dim(a_strides);
|
||||
auto b_s_dim = leftmost_s_dim(b_strides);
|
||||
|
||||
auto ndim = new_shape.size();
|
||||
|
||||
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
|
||||
int dim = ndim;
|
||||
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
|
||||
bopt = BinaryOpType::VectorVector;
|
||||
dim = d;
|
||||
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
|
||||
bopt = BinaryOpType::VectorScalar;
|
||||
dim = d;
|
||||
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
|
||||
bopt = BinaryOpType::ScalarVector;
|
||||
dim = d;
|
||||
}
|
||||
|
||||
// Can be sure dim > 0 since otherwise we would have used one of the fully
|
||||
// contiguous methods above. Except for the case that the flags do not
|
||||
// correspond to the underlying contiguity.
|
||||
if (dim == 0 || strides[dim - 1] < 16) {
|
||||
bopt = BinaryOpType::General;
|
||||
dim = ndim;
|
||||
}
|
||||
|
||||
switch (bopt) {
|
||||
case BinaryOpType::VectorVector:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
VectorVector{op},
|
||||
dim,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
VectorScalar{op},
|
||||
dim,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
ScalarVector{op},
|
||||
dim,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
default:
|
||||
binary_op_dispatch_dims<T, U, false>(
|
||||
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||
binary_op<T, T>(a, b, out, op);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary(const array& a, const array& b, array& out, Op op) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, out, op);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, out, op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, out, op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, out, op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, out, op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, out, op);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, out, op);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t>(a, b, out, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -3,7 +3,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@ -23,17 +22,4 @@ enum class CopyType {
|
||||
GeneralGeneral
|
||||
};
|
||||
|
||||
void copy(const array& src, array& dst, CopyType ctype);
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype);
|
||||
|
||||
void copy_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1,40 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cmath>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
/* Approximation to the inverse error function.
|
||||
* Based on code from:
|
||||
* https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
|
||||
*/
|
||||
float erfinv(float a) {
|
||||
auto t = std::fma(a, 0.0f - a, 1.0f);
|
||||
t = std::log(t);
|
||||
float p;
|
||||
if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
|
||||
p = 3.03697567e-10f; // 0x1.4deb44p-32
|
||||
p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
|
||||
p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
|
||||
p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
|
||||
p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
|
||||
p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
|
||||
p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
|
||||
p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
|
||||
p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
|
||||
} else { // maximum ulp error = 2.35002
|
||||
p = 5.43877832e-9f; // 0x1.75c000p-28
|
||||
p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
|
||||
p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
|
||||
p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
|
||||
p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
|
||||
p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
|
||||
p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
|
||||
p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
|
||||
p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
|
||||
p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
|
||||
}
|
||||
return a * p;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -1,377 +1,147 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename U>
|
||||
struct Limits {
|
||||
static const U max;
|
||||
static const U min;
|
||||
};
|
||||
|
||||
#define instantiate_default_limit(type) \
|
||||
template <> \
|
||||
struct Limits<type> { \
|
||||
static constexpr type max = std::numeric_limits<type>::max(); \
|
||||
static constexpr type min = std::numeric_limits<type>::min(); \
|
||||
};
|
||||
|
||||
instantiate_default_limit(uint8_t);
|
||||
instantiate_default_limit(uint16_t);
|
||||
instantiate_default_limit(uint32_t);
|
||||
instantiate_default_limit(uint64_t);
|
||||
instantiate_default_limit(int8_t);
|
||||
instantiate_default_limit(int16_t);
|
||||
instantiate_default_limit(int32_t);
|
||||
instantiate_default_limit(int64_t);
|
||||
|
||||
#define instantiate_float_limit(type) \
|
||||
template <> \
|
||||
struct Limits<type> { \
|
||||
static const type max; \
|
||||
static const type min; \
|
||||
};
|
||||
|
||||
instantiate_float_limit(float16_t);
|
||||
instantiate_float_limit(bfloat16_t);
|
||||
instantiate_float_limit(float);
|
||||
instantiate_float_limit(complex64_t);
|
||||
|
||||
template <>
|
||||
struct Limits<bool> {
|
||||
static constexpr bool max = true;
|
||||
static constexpr bool min = false;
|
||||
};
|
||||
|
||||
const float Limits<float>::max = std::numeric_limits<float>::infinity();
|
||||
const float Limits<float>::min = -std::numeric_limits<float>::infinity();
|
||||
const bfloat16_t Limits<bfloat16_t>::max =
|
||||
std::numeric_limits<float>::infinity();
|
||||
const bfloat16_t Limits<bfloat16_t>::min =
|
||||
-std::numeric_limits<float>::infinity();
|
||||
const float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();
|
||||
const float16_t Limits<float16_t>::min =
|
||||
-std::numeric_limits<float>::infinity();
|
||||
const complex64_t Limits<complex64_t>::max =
|
||||
std::numeric_limits<float>::infinity();
|
||||
const complex64_t Limits<complex64_t>::min =
|
||||
-std::numeric_limits<float>::infinity();
|
||||
|
||||
struct AndReduce {
|
||||
template <typename T>
|
||||
bool operator()(bool x, T y) {
|
||||
return x & (y != 0);
|
||||
}
|
||||
|
||||
bool operator()(bool x, bool y) {
|
||||
return x & y;
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
|
||||
return x & (y != 0);
|
||||
};
|
||||
|
||||
template <int N>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
|
||||
return x & y;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
bool operator()(simd::Simd<T, N> x) {
|
||||
return simd::all(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct OrReduce {
|
||||
template <typename T>
|
||||
bool operator()(bool x, T y) {
|
||||
return x | (y != 0);
|
||||
}
|
||||
|
||||
bool operator()(bool x, bool y) {
|
||||
return x | y;
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
|
||||
return x | (y != 0);
|
||||
};
|
||||
|
||||
template <int N>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
|
||||
return x | y;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
bool operator()(simd::Simd<T, N> x) {
|
||||
return simd::any(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct MaxReduce {
|
||||
template <typename T>
|
||||
T operator()(T y, T x) {
|
||||
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
|
||||
return simd::maximum(x, y);
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::max(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct MinReduce {
|
||||
template <typename T>
|
||||
T operator()(T y, T x) {
|
||||
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
|
||||
return simd::minimum(x, y);
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::min(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct SumReduce {
|
||||
template <typename T, typename U>
|
||||
U operator()(U y, T x) {
|
||||
return x + y;
|
||||
};
|
||||
|
||||
template <int N, typename T, typename U>
|
||||
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
|
||||
return y + x;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::sum(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ProdReduce {
|
||||
template <typename T, typename U>
|
||||
U operator()(U y, T x) {
|
||||
return x * y;
|
||||
};
|
||||
|
||||
template <int N, typename T, typename U>
|
||||
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
|
||||
return x * y;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::prod(x);
|
||||
};
|
||||
};
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_and_or(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::And) {
|
||||
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
|
||||
auto shape = x.shape();
|
||||
auto strides = x.strides();
|
||||
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int a = axes[i];
|
||||
shape.erase(shape.begin() + a);
|
||||
strides.erase(strides.begin() + a);
|
||||
}
|
||||
|
||||
return std::make_pair(shape, strides);
|
||||
}
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
x.flags().contiguous) {
|
||||
return ContiguousAllReduce;
|
||||
}
|
||||
|
||||
// Row contiguous input so the output is row contiguous
|
||||
if (x.flags().row_contiguous) {
|
||||
// Merge consecutive axes
|
||||
Shape shape = {x.shape(axes[0])};
|
||||
Strides strides = {x.strides()[axes[0]]};
|
||||
for (int i = 1; i < axes.size(); i++) {
|
||||
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
|
||||
shape.back() *= x.shape(axes[i]);
|
||||
strides.back() = x.strides()[axes[i]];
|
||||
} else {
|
||||
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
|
||||
shape.push_back(x.shape(axes[i]));
|
||||
strides.push_back(x.strides()[axes[i]]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_sum_prod(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Sum) {
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_min_max(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Max) {
|
||||
auto init = Limits<InT>::min;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
|
||||
} else {
|
||||
auto init = Limits<InT>::max;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
|
||||
// Remove singleton axes from the plan
|
||||
for (int i = shape.size() - 1; i >= 0; i--) {
|
||||
if (shape[i] == 1) {
|
||||
shape.erase(shape.begin() + i);
|
||||
strides.erase(strides.begin() + i);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void nd_loop(
|
||||
std::function<void(int)> callback,
|
||||
const Shape& shape,
|
||||
const Strides& strides) {
|
||||
std::function<void(int, int)> loop_inner;
|
||||
loop_inner = [&](int dim, int offset) {
|
||||
if (dim < shape.size() - 1) {
|
||||
auto size = shape[dim];
|
||||
auto stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
loop_inner(dim + 1, offset + i * stride);
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(ContiguousReduce, shape, strides);
|
||||
} else if (strides.back() > 1) {
|
||||
return ReductionPlan(ContiguousStridedReduce, shape, strides);
|
||||
}
|
||||
} else {
|
||||
auto size = shape[dim];
|
||||
auto stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
callback(offset + i * stride);
|
||||
}
|
||||
}
|
||||
};
|
||||
loop_inner(0, 0);
|
||||
}
|
||||
|
||||
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
switch (reduce_type_) {
|
||||
case Reduce::And:
|
||||
case Reduce::Or: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
case float16:
|
||||
case bfloat16:
|
||||
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
case int32:
|
||||
case float32:
|
||||
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
case int64:
|
||||
case complex64:
|
||||
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Reduce::Sum:
|
||||
case Reduce::Prod: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Reduce::Max:
|
||||
case Reduce::Min: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
// Let's check if we can optimize our access patterns
|
||||
//
|
||||
// 1. We have a reduction axis with stride 1. Simply call
|
||||
// GeneralContiguousReduce and be done with it.
|
||||
// 2. We have transpositions and we are not reducing over the axis with
|
||||
// stride 1. However, we are reducing over an axis where everything is
|
||||
// contiguous in memory to the right of that axis. We can call strided
|
||||
// reduce and be done with it.
|
||||
// 2. We have weird transpositions and expands. Copy the strides to the
|
||||
// output, then call strided reduce.
|
||||
|
||||
// Sort reduction axes by stride in order to merge them and figure out if we
|
||||
// have a contiguous reduction.
|
||||
std::vector<std::pair<int, int64_t>> reductions;
|
||||
for (auto a : axes) {
|
||||
if (x.shape(a) > 1) {
|
||||
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
|
||||
}
|
||||
}
|
||||
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
|
||||
bool a_is_zero = a.second == 0;
|
||||
bool b_is_zero = b.second == 0;
|
||||
return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
|
||||
});
|
||||
// Extract the two smallest and try to merge them in case the contiguous
|
||||
// reduction can be bigger than just the last axis.
|
||||
for (int i = reductions.size() - 1; i >= 1; i--) {
|
||||
auto a = reductions[i];
|
||||
auto b = reductions[i - 1];
|
||||
|
||||
// b.stride = a.shape * a.stride then a and b are contiguous
|
||||
if (b.second == a.first * a.second) {
|
||||
reductions.erase(reductions.begin() + i);
|
||||
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
|
||||
}
|
||||
}
|
||||
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
for (auto r : reductions) {
|
||||
shape.push_back(r.first);
|
||||
strides.push_back(r.second);
|
||||
}
|
||||
|
||||
// We can call the contiguous reduction op for every weird way the input is
|
||||
// structured in the rest of the axes.
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(GeneralContiguousReduce, shape, strides);
|
||||
}
|
||||
|
||||
// Delegate to the general strided reduction op if the axes after
|
||||
// strides.back() are contiguous.
|
||||
if (strides.back() > 1) {
|
||||
int64_t size = 1;
|
||||
bool have_expand = false;
|
||||
for (int i = x.ndim() - 1; i >= 0; i--) {
|
||||
if (axes.back() == i) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto stride_i = x.strides()[i];
|
||||
auto shape_i = x.shape(i);
|
||||
if (stride_i == 0) {
|
||||
if (shape_i == 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
have_expand = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (stride_i != size && shape_i != 1) {
|
||||
break;
|
||||
}
|
||||
size *= shape_i;
|
||||
}
|
||||
// In the case of an expanded dimension we are being conservative and
|
||||
// require the smallest reduction stride to be smaller than the maximum row
|
||||
// contiguous size. The reason is that we can't easily know if the reduced
|
||||
// axis is before or after an expanded dimension.
|
||||
if (size > strides.back() || (size == strides.back() && !have_expand)) {
|
||||
return ReductionPlan(GeneralStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
return ReductionPlan(GeneralReduce, shape, strides);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@ -49,193 +48,8 @@ struct ReductionPlan {
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
|
||||
|
||||
// Helper for the ndimensional strided loop
|
||||
// Should this be in utils?
|
||||
void nd_loop(
|
||||
std::function<void(int)> callback,
|
||||
const Shape& shape,
|
||||
const Strides& strides);
|
||||
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes);
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void strided_reduce(
|
||||
const T* x,
|
||||
U* accumulator,
|
||||
int size,
|
||||
size_t stride,
|
||||
Op op) {
|
||||
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
|
||||
for (int i = 0; i < size; i++) {
|
||||
U* moving_accumulator = accumulator;
|
||||
auto s = stride;
|
||||
while (s >= N) {
|
||||
auto acc = simd::load<U, N>(moving_accumulator);
|
||||
auto v = simd::Simd<U, N>(simd::load<T, N>(x));
|
||||
simd::store<U, N>(moving_accumulator, op(acc, v));
|
||||
moving_accumulator += N;
|
||||
x += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
*moving_accumulator = op(*moving_accumulator, *x);
|
||||
moving_accumulator++;
|
||||
x++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void contiguous_reduce(const T* x, U* accumulator, int size, Op op, U init) {
|
||||
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
|
||||
simd::Simd<U, N> accumulator_v(init);
|
||||
while (size >= N) {
|
||||
accumulator_v = op(accumulator_v, simd::Simd<U, N>(simd::load<T, N>(x)));
|
||||
x += N;
|
||||
size -= N;
|
||||
}
|
||||
*accumulator = op(*accumulator, op(accumulator_v));
|
||||
while (size-- > 0) {
|
||||
*accumulator = op(*accumulator, *x);
|
||||
x++;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void reduction_op(
|
||||
const array& x,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
U init,
|
||||
Op op) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
ReductionPlan plan = get_reduction_plan(x, axes);
|
||||
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
U* out_ptr = out.data<U>();
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(x.data<T>(), out_ptr, x.size(), op, init);
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
|
||||
int reduction_size = plan.shape[0];
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) {
|
||||
int reduction_size = plan.shape.back();
|
||||
plan.shape.pop_back();
|
||||
plan.strides.pop_back();
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
// Unrolling the following loop (and implementing it in order for
|
||||
// ContiguousReduce) should hold extra performance boost.
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
contiguous_reduce(
|
||||
x_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
op,
|
||||
init);
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) {
|
||||
int reduction_size = plan.shape.back();
|
||||
size_t reduction_stride = plan.strides.back();
|
||||
plan.shape.pop_back();
|
||||
plan.strides.pop_back();
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op);
|
||||
x_ptr += reduction_stride * reduction_size;
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == GeneralStridedReduce ||
|
||||
plan.type == ContiguousStridedReduce) {
|
||||
int reduction_size = plan.shape.back();
|
||||
size_t reduction_stride = plan.strides.back();
|
||||
plan.shape.pop_back();
|
||||
plan.strides.pop_back();
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(
|
||||
x_ptr + offset, out_ptr, reduction_size, reduction_stride, op);
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
strided_reduce(
|
||||
x_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
op);
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == GeneralReduce) {
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
U val = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
val = op(val, *(x_ptr + offset + extra_offset));
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
*out_ptr = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1,147 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes) {
|
||||
auto shape = x.shape();
|
||||
auto strides = x.strides();
|
||||
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int a = axes[i];
|
||||
shape.erase(shape.begin() + a);
|
||||
strides.erase(strides.begin() + a);
|
||||
}
|
||||
|
||||
return std::make_pair(shape, strides);
|
||||
}
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
x.flags().contiguous) {
|
||||
return ContiguousAllReduce;
|
||||
}
|
||||
|
||||
// Row contiguous input so the output is row contiguous
|
||||
if (x.flags().row_contiguous) {
|
||||
// Merge consecutive axes
|
||||
Shape shape = {x.shape(axes[0])};
|
||||
Strides strides = {x.strides()[axes[0]]};
|
||||
for (int i = 1; i < axes.size(); i++) {
|
||||
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
|
||||
shape.back() *= x.shape(axes[i]);
|
||||
strides.back() = x.strides()[axes[i]];
|
||||
} else {
|
||||
shape.push_back(x.shape(axes[i]));
|
||||
strides.push_back(x.strides()[axes[i]]);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove singleton axes from the plan
|
||||
for (int i = shape.size() - 1; i >= 0; i--) {
|
||||
if (shape[i] == 1) {
|
||||
shape.erase(shape.begin() + i);
|
||||
strides.erase(strides.begin() + i);
|
||||
}
|
||||
}
|
||||
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(ContiguousReduce, shape, strides);
|
||||
} else if (strides.back() > 1) {
|
||||
return ReductionPlan(ContiguousStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
// Let's check if we can optimize our access patterns
|
||||
//
|
||||
// 1. We have a reduction axis with stride 1. Simply call
|
||||
// GeneralContiguousReduce and be done with it.
|
||||
// 2. We have transpositions and we are not reducing over the axis with
|
||||
// stride 1. However, we are reducing over an axis where everything is
|
||||
// contiguous in memory to the right of that axis. We can call strided
|
||||
// reduce and be done with it.
|
||||
// 2. We have weird transpositions and expands. Copy the strides to the
|
||||
// output, then call strided reduce.
|
||||
|
||||
// Sort reduction axes by stride in order to merge them and figure out if we
|
||||
// have a contiguous reduction.
|
||||
std::vector<std::pair<int, int64_t>> reductions;
|
||||
for (auto a : axes) {
|
||||
if (x.shape(a) > 1) {
|
||||
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
|
||||
}
|
||||
}
|
||||
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
|
||||
bool a_is_zero = a.second == 0;
|
||||
bool b_is_zero = b.second == 0;
|
||||
return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
|
||||
});
|
||||
// Extract the two smallest and try to merge them in case the contiguous
|
||||
// reduction can be bigger than just the last axis.
|
||||
for (int i = reductions.size() - 1; i >= 1; i--) {
|
||||
auto a = reductions[i];
|
||||
auto b = reductions[i - 1];
|
||||
|
||||
// b.stride = a.shape * a.stride then a and b are contiguous
|
||||
if (b.second == a.first * a.second) {
|
||||
reductions.erase(reductions.begin() + i);
|
||||
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
|
||||
}
|
||||
}
|
||||
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
for (auto r : reductions) {
|
||||
shape.push_back(r.first);
|
||||
strides.push_back(r.second);
|
||||
}
|
||||
|
||||
// We can call the contiguous reduction op for every weird way the input is
|
||||
// structured in the rest of the axes.
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(GeneralContiguousReduce, shape, strides);
|
||||
}
|
||||
|
||||
// Delegate to the general strided reduction op if the axes after
|
||||
// strides.back() are contiguous.
|
||||
if (strides.back() > 1) {
|
||||
int64_t size = 1;
|
||||
bool have_expand = false;
|
||||
for (int i = x.ndim() - 1; i >= 0; i--) {
|
||||
if (axes.back() == i) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto stride_i = x.strides()[i];
|
||||
auto shape_i = x.shape(i);
|
||||
if (stride_i == 0) {
|
||||
if (shape_i == 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
have_expand = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (stride_i != size && shape_i != 1) {
|
||||
break;
|
||||
}
|
||||
size *= shape_i;
|
||||
}
|
||||
// In the case of an expanded dimension we are being conservative and
|
||||
// require the smallest reduction stride to be smaller than the maximum row
|
||||
// contiguous size. The reason is that we can't easily know if the reduced
|
||||
// axis is before or after an expanded dimension.
|
||||
if (size > strides.back() || (size == strides.back() && !have_expand)) {
|
||||
return ReductionPlan(GeneralStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
return ReductionPlan(GeneralReduce, shape, strides);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -1,4 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/math.h"
|
||||
#include "mlx/backend/common/simd/type.h"
|
@ -1,7 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/base_simd.h"
|
||||
|
||||
#ifdef MLX_USE_ACCELERATE
|
||||
#include "mlx/backend/common/simd/accelerate_simd.h"
|
||||
#endif
|
@ -7,8 +7,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO: Add support for more combinations of input types.
|
||||
enum class TernaryOpType {
|
||||
ScalarScalarScalar,
|
||||
@ -16,7 +14,7 @@ enum class TernaryOpType {
|
||||
General,
|
||||
};
|
||||
|
||||
TernaryOpType
|
||||
inline TernaryOpType
|
||||
get_ternary_op_type(const array& a, const array& b, const array& c) {
|
||||
TernaryOpType topt;
|
||||
if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
|
||||
@ -33,7 +31,7 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
|
||||
return topt;
|
||||
}
|
||||
|
||||
void set_ternary_op_output_data(
|
||||
inline void set_ternary_op_output_data(
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
@ -76,152 +74,5 @@ void set_ternary_op_output_data(
|
||||
break;
|
||||
}
|
||||
}
|
||||
template <typename T1, typename T2, typename T3, typename U, typename Op, int D>
|
||||
void ternary_op_dims(
|
||||
const T1* a,
|
||||
const T2* b,
|
||||
const T3* c,
|
||||
U* out,
|
||||
Op op,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& c_strides,
|
||||
const Strides& out_strides,
|
||||
int axis) {
|
||||
auto stride_a = a_strides[axis];
|
||||
auto stride_b = b_strides[axis];
|
||||
auto stride_c = c_strides[axis];
|
||||
auto stride_out = out_strides[axis];
|
||||
auto N = shape[axis];
|
||||
|
||||
for (int i = 0; i < N; i++) {
|
||||
if constexpr (D > 1) {
|
||||
ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
out,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
c_strides,
|
||||
out_strides,
|
||||
axis + 1);
|
||||
} else {
|
||||
*out = op(*a, *b, *c);
|
||||
}
|
||||
a += stride_a;
|
||||
b += stride_b;
|
||||
c += stride_c;
|
||||
out += stride_out;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, typename T3, typename U, typename Op>
|
||||
void ternary_op_dispatch_dims(
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
Op op) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
|
||||
const auto& a_strides = strides[0];
|
||||
const auto& b_strides = strides[1];
|
||||
const auto& c_strides = strides[2];
|
||||
const auto& out_strides = strides[3];
|
||||
|
||||
const T1* a_ptr = a.data<T1>();
|
||||
const T2* b_ptr = b.data<T2>();
|
||||
const T3* c_ptr = c.data<T3>();
|
||||
U* out_ptr = out.data<T3>();
|
||||
int ndim = shape.size();
|
||||
switch (ndim) {
|
||||
case 1:
|
||||
ternary_op_dims<T1, T2, T3, U, Op, 1>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
c_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
case 2:
|
||||
ternary_op_dims<T1, T2, T3, U, Op, 2>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
c_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator b_it(shape, b_strides, ndim - 2);
|
||||
ContiguousIterator c_it(shape, c_strides, ndim - 2);
|
||||
auto stride = out_strides[ndim - 3];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
ternary_op_dims<T1, T2, T3, U, Op, 2>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
c_ptr + c_it.loc,
|
||||
out_ptr + elem,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
c_strides,
|
||||
out_strides,
|
||||
ndim - 2);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
c_it.step();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, typename T3, typename U, typename Op>
|
||||
void ternary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
Op op) {
|
||||
TernaryOpType topt = get_ternary_op_type(a, b, c);
|
||||
set_ternary_op_output_data(a, b, c, out, topt);
|
||||
|
||||
// The full computation is scalar-scalar-scalar so we call the base op once.
|
||||
if (topt == TernaryOpType::ScalarScalarScalar) {
|
||||
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
|
||||
} else if (topt == TernaryOpType::VectorVectorVector) {
|
||||
const T1* a_ptr = a.data<T1>();
|
||||
const T2* b_ptr = b.data<T2>();
|
||||
const T3* c_ptr = c.data<T3>();
|
||||
U* out_ptr = out.data<U>();
|
||||
for (size_t i = 0; i < out.size(); ++i) {
|
||||
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
|
||||
a_ptr++;
|
||||
b_ptr++;
|
||||
c_ptr++;
|
||||
out_ptr++;
|
||||
}
|
||||
} else {
|
||||
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
||||
|
81
mlx/backend/cpu/CMakeLists.txt
Normal file
81
mlx/backend/cpu/CMakeLists.txt
Normal file
@ -0,0 +1,81 @@
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
set(COMPILER ${CMAKE_C_COMPILER})
|
||||
set(CLANG TRUE)
|
||||
else()
|
||||
set(COMPILER ${CMAKE_CXX_COMPILER})
|
||||
endif()
|
||||
|
||||
set(COMPILE_DEPS
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
|
||||
simd/simd.h
|
||||
simd/base_simd.h
|
||||
simd/math.h
|
||||
simd/type.h
|
||||
unary_ops.h
|
||||
binary_ops.h)
|
||||
|
||||
if(MSVC)
|
||||
set(SHELL_EXT ps1)
|
||||
set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
|
||||
else()
|
||||
set(SHELL_EXT sh)
|
||||
set(SHELL_CMD bash)
|
||||
endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT compiled_preamble.cpp
|
||||
COMMAND
|
||||
${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
|
||||
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
|
||||
DEPENDS make_compiled_preamble.${SHELL_EXT} compiled_preamble.h
|
||||
${COMPILE_DEPS})
|
||||
|
||||
add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
|
||||
|
||||
add_dependencies(mlx cpu_compiled_preamble)
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cblas.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
|
||||
|
||||
if(MLX_BUILD_ACCELERATE)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
|
||||
endif()
|
||||
|
||||
if(IOS)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../no_cpu/compiled.cpp)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp)
|
||||
endif()
|
@ -2,8 +2,8 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
@ -5,9 +5,9 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/binary_ops.h"
|
||||
#include "mlx/backend/common/binary_two.h"
|
||||
#include "mlx/backend/cpu/binary.h"
|
||||
#include "mlx/backend/cpu/binary_ops.h"
|
||||
#include "mlx/backend/cpu/binary_two.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
370
mlx/backend/cpu/binary.h
Normal file
370
mlx/backend/cpu/binary.h
Normal file
@ -0,0 +1,370 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename Op>
|
||||
struct VectorScalar {
|
||||
Op op;
|
||||
|
||||
VectorScalar(Op op_) : op(op_) {}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
T scalar = *b;
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
|
||||
dst += N;
|
||||
a += N;
|
||||
size -= N;
|
||||
}
|
||||
while (size-- > 0) {
|
||||
*dst = op(*a, scalar);
|
||||
dst++;
|
||||
a++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Op>
|
||||
struct ScalarVector {
|
||||
Op op;
|
||||
|
||||
ScalarVector(Op op_) : op(op_) {}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
T scalar = *a;
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
|
||||
dst += N;
|
||||
b += N;
|
||||
size -= N;
|
||||
}
|
||||
while (size-- > 0) {
|
||||
*dst = op(scalar, *b);
|
||||
dst++;
|
||||
b++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Op>
|
||||
struct VectorVector {
|
||||
Op op;
|
||||
|
||||
VectorVector(Op op_) : op(op_) {}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::load<T, N>(a), simd::load<T, N>(b)));
|
||||
dst += N;
|
||||
a += N;
|
||||
b += N;
|
||||
size -= N;
|
||||
}
|
||||
while (size-- > 0) {
|
||||
*dst = op(*a, *b);
|
||||
dst++;
|
||||
a++;
|
||||
b++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op, int D, bool Strided>
|
||||
void binary_op_dims(
|
||||
const T* a,
|
||||
const T* b,
|
||||
U* out,
|
||||
Op op,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& out_strides,
|
||||
int axis) {
|
||||
auto stride_a = a_strides[axis];
|
||||
auto stride_b = b_strides[axis];
|
||||
auto stride_out = out_strides[axis];
|
||||
auto N = shape[axis];
|
||||
|
||||
for (int i = 0; i < N; i++) {
|
||||
if constexpr (D > 1) {
|
||||
binary_op_dims<T, U, Op, D - 1, Strided>(
|
||||
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
|
||||
} else {
|
||||
if constexpr (Strided) {
|
||||
op(a, b, out, stride_out);
|
||||
} else {
|
||||
*out = op(*a, *b);
|
||||
}
|
||||
}
|
||||
out += stride_out;
|
||||
a += stride_a;
|
||||
b += stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, bool Strided, typename Op>
|
||||
void binary_op_dispatch_dims(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
int dim,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& out_strides) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
switch (dim) {
|
||||
case 1:
|
||||
binary_op_dims<T, U, Op, 1, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
case 2:
|
||||
binary_op_dims<T, U, Op, 2, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
case 3:
|
||||
binary_op_dims<T, U, Op, 3, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(shape, a_strides, dim - 3);
|
||||
ContiguousIterator b_it(shape, b_strides, dim - 3);
|
||||
auto stride = out_strides[dim - 4];
|
||||
for (int64_t elem = 0; elem < a.size(); elem += stride) {
|
||||
binary_op_dims<T, U, Op, 3, Strided>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
out_ptr + elem,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
dim - 3);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
// The full computation is scalar scalar so call the base op once
|
||||
if (bopt == BinaryOpType::ScalarScalar) {
|
||||
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is scalar vector so delegate to the op
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector scalar so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorScalar) {
|
||||
VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector vector so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorVector) {
|
||||
VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
|
||||
return;
|
||||
}
|
||||
|
||||
// General computation so let's try to optimize
|
||||
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), out.strides()});
|
||||
const auto& a_strides = new_strides[0];
|
||||
const auto& b_strides = new_strides[1];
|
||||
const auto& strides = new_strides[2];
|
||||
|
||||
// Get the left-most dim such that the array is row contiguous after
|
||||
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_rc_dim = leftmost_rc_dim(a_strides);
|
||||
auto b_rc_dim = leftmost_rc_dim(b_strides);
|
||||
|
||||
// Get the left-most dim such that the array is a broadcasted "scalar" after
|
||||
auto leftmost_s_dim = [](const auto& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == 0; d--) {
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_s_dim = leftmost_s_dim(a_strides);
|
||||
auto b_s_dim = leftmost_s_dim(b_strides);
|
||||
|
||||
auto ndim = new_shape.size();
|
||||
|
||||
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
|
||||
int dim = ndim;
|
||||
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
|
||||
bopt = BinaryOpType::VectorVector;
|
||||
dim = d;
|
||||
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
|
||||
bopt = BinaryOpType::VectorScalar;
|
||||
dim = d;
|
||||
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
|
||||
bopt = BinaryOpType::ScalarVector;
|
||||
dim = d;
|
||||
}
|
||||
|
||||
// Can be sure dim > 0 since otherwise we would have used one of the fully
|
||||
// contiguous methods above. Except for the case that the flags do not
|
||||
// correspond to the underlying contiguity.
|
||||
if (dim == 0 || strides[dim - 1] < 16) {
|
||||
bopt = BinaryOpType::General;
|
||||
dim = ndim;
|
||||
}
|
||||
|
||||
switch (bopt) {
|
||||
case BinaryOpType::VectorVector:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
VectorVector{op},
|
||||
dim,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
VectorScalar{op},
|
||||
dim,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
ScalarVector{op},
|
||||
dim,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
default:
|
||||
binary_op_dispatch_dims<T, U, false>(
|
||||
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||
binary_op<T, T>(a, b, out, op);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary(const array& a, const array& b, array& out, Op op) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, out, op);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, out, op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, out, op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, out, op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, out, op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, out, op);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, out, op);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t>(a, b, out, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -2,7 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
|
||||
namespace mlx::core::detail {
|
||||
|
@ -2,8 +2,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/binary.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
@ -1,8 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
@ -10,8 +10,8 @@
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/compiled_preamble.h"
|
||||
#include "mlx/backend/common/jit_compiler.h"
|
||||
#include "mlx/backend/cpu/compiled_preamble.h"
|
||||
#include "mlx/backend/cpu/jit_compiler.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
|
@ -5,8 +5,8 @@
|
||||
// clang-format off
|
||||
#include "mlx/types/half_types.h"
|
||||
#include "mlx/types/complex.h"
|
||||
#include "mlx/backend/common/unary_ops.h"
|
||||
#include "mlx/backend/common/binary_ops.h"
|
||||
#include "mlx/backend/cpu/unary_ops.h"
|
||||
#include "mlx/backend/cpu/binary_ops.h"
|
||||
// clang-format on
|
||||
|
||||
const char* get_kernel_preamble();
|
@ -3,8 +3,8 @@
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
@ -3,9 +3,9 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
24
mlx/backend/cpu/copy.h
Normal file
24
mlx/backend/cpu/copy.h
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy(const array& src, array& dst, CopyType ctype);
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype);
|
||||
|
||||
void copy_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype);
|
||||
|
||||
} // namespace mlx::core
|
@ -2,8 +2,8 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
@ -3,8 +3,8 @@
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/gemm.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
#include "mlx/dtype.h"
|
||||
|
||||
namespace mlx::core {
|
@ -1,8 +1,8 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/gemm.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
@ -1,6 +1,6 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/gemm.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
@ -1,6 +1,6 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/gemm.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
@ -2,8 +2,8 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/hadamard.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
@ -6,8 +6,8 @@
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
@ -1,8 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
|
@ -1,6 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/jit_compiler.h"
|
||||
#include "mlx/backend/cpu/jit_compiler.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <vector>
|
@ -8,7 +8,7 @@ $CL = $args[1]
|
||||
$SRCDIR = $args[2]
|
||||
|
||||
# Get command result as array.
|
||||
$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/compiled_preamble.h"
|
||||
$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/cpu/compiled_preamble.h"
|
||||
# Remove empty lines.
|
||||
# Otherwise there will be too much empty lines making the result unreadable.
|
||||
$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }
|
@ -24,7 +24,7 @@ else
|
||||
CC_FLAGS="-std=c++17"
|
||||
fi
|
||||
|
||||
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
|
||||
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/cpu/compiled_preamble.h" 2>/dev/null)
|
||||
|
||||
cat << EOF > "$OUTPUT_FILE"
|
||||
const char* get_kernel_preamble() {
|
@ -3,9 +3,9 @@
|
||||
#include <cstring>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
@ -2,8 +2,8 @@
|
||||
|
||||
#include <cstring>
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/gemm.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
@ -7,12 +7,12 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/arange.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/load.h"
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/common/threefry.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/arange.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/threefry.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
@ -1,8 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
@ -2,8 +2,8 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
552
mlx/backend/cpu/reduce.cpp
Normal file
552
mlx/backend/cpu/reduce.cpp
Normal file
@ -0,0 +1,552 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename U>
|
||||
struct Limits {
|
||||
static const U max;
|
||||
static const U min;
|
||||
};
|
||||
|
||||
#define instantiate_default_limit(type) \
|
||||
template <> \
|
||||
struct Limits<type> { \
|
||||
static constexpr type max = std::numeric_limits<type>::max(); \
|
||||
static constexpr type min = std::numeric_limits<type>::min(); \
|
||||
};
|
||||
|
||||
instantiate_default_limit(uint8_t);
|
||||
instantiate_default_limit(uint16_t);
|
||||
instantiate_default_limit(uint32_t);
|
||||
instantiate_default_limit(uint64_t);
|
||||
instantiate_default_limit(int8_t);
|
||||
instantiate_default_limit(int16_t);
|
||||
instantiate_default_limit(int32_t);
|
||||
instantiate_default_limit(int64_t);
|
||||
|
||||
#define instantiate_float_limit(type) \
|
||||
template <> \
|
||||
struct Limits<type> { \
|
||||
static const type max; \
|
||||
static const type min; \
|
||||
};
|
||||
|
||||
instantiate_float_limit(float16_t);
|
||||
instantiate_float_limit(bfloat16_t);
|
||||
instantiate_float_limit(float);
|
||||
instantiate_float_limit(complex64_t);
|
||||
|
||||
template <>
|
||||
struct Limits<bool> {
|
||||
static constexpr bool max = true;
|
||||
static constexpr bool min = false;
|
||||
};
|
||||
|
||||
const float Limits<float>::max = std::numeric_limits<float>::infinity();
|
||||
const float Limits<float>::min = -std::numeric_limits<float>::infinity();
|
||||
const bfloat16_t Limits<bfloat16_t>::max =
|
||||
std::numeric_limits<float>::infinity();
|
||||
const bfloat16_t Limits<bfloat16_t>::min =
|
||||
-std::numeric_limits<float>::infinity();
|
||||
const float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();
|
||||
const float16_t Limits<float16_t>::min =
|
||||
-std::numeric_limits<float>::infinity();
|
||||
const complex64_t Limits<complex64_t>::max =
|
||||
std::numeric_limits<float>::infinity();
|
||||
const complex64_t Limits<complex64_t>::min =
|
||||
-std::numeric_limits<float>::infinity();
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void strided_reduce(
|
||||
const T* x,
|
||||
U* accumulator,
|
||||
int size,
|
||||
size_t stride,
|
||||
Op op) {
|
||||
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
|
||||
for (int i = 0; i < size; i++) {
|
||||
U* moving_accumulator = accumulator;
|
||||
auto s = stride;
|
||||
while (s >= N) {
|
||||
auto acc = simd::load<U, N>(moving_accumulator);
|
||||
auto v = simd::Simd<U, N>(simd::load<T, N>(x));
|
||||
simd::store<U, N>(moving_accumulator, op(acc, v));
|
||||
moving_accumulator += N;
|
||||
x += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
*moving_accumulator = op(*moving_accumulator, *x);
|
||||
moving_accumulator++;
|
||||
x++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void contiguous_reduce(const T* x, U* accumulator, int size, Op op, U init) {
|
||||
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
|
||||
simd::Simd<U, N> accumulator_v(init);
|
||||
while (size >= N) {
|
||||
accumulator_v = op(accumulator_v, simd::Simd<U, N>(simd::load<T, N>(x)));
|
||||
x += N;
|
||||
size -= N;
|
||||
}
|
||||
*accumulator = op(*accumulator, op(accumulator_v));
|
||||
while (size-- > 0) {
|
||||
*accumulator = op(*accumulator, *x);
|
||||
x++;
|
||||
}
|
||||
}
|
||||
|
||||
// Helper for the ndimensional strided loop
|
||||
void nd_loop(
|
||||
std::function<void(int)> callback,
|
||||
const Shape& shape,
|
||||
const Strides& strides) {
|
||||
std::function<void(int, int)> loop_inner;
|
||||
loop_inner = [&](int dim, int offset) {
|
||||
if (dim < shape.size() - 1) {
|
||||
auto size = shape[dim];
|
||||
auto stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
loop_inner(dim + 1, offset + i * stride);
|
||||
}
|
||||
} else {
|
||||
auto size = shape[dim];
|
||||
auto stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
callback(offset + i * stride);
|
||||
}
|
||||
}
|
||||
};
|
||||
loop_inner(0, 0);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void reduction_op(
|
||||
const array& x,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
U init,
|
||||
Op op) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
ReductionPlan plan = get_reduction_plan(x, axes);
|
||||
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
U* out_ptr = out.data<U>();
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(x.data<T>(), out_ptr, x.size(), op, init);
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
|
||||
int reduction_size = plan.shape[0];
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) {
|
||||
int reduction_size = plan.shape.back();
|
||||
plan.shape.pop_back();
|
||||
plan.strides.pop_back();
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
// Unrolling the following loop (and implementing it in order for
|
||||
// ContiguousReduce) should hold extra performance boost.
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
contiguous_reduce(
|
||||
x_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
op,
|
||||
init);
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) {
|
||||
int reduction_size = plan.shape.back();
|
||||
size_t reduction_stride = plan.strides.back();
|
||||
plan.shape.pop_back();
|
||||
plan.strides.pop_back();
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op);
|
||||
x_ptr += reduction_stride * reduction_size;
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == GeneralStridedReduce ||
|
||||
plan.type == ContiguousStridedReduce) {
|
||||
int reduction_size = plan.shape.back();
|
||||
size_t reduction_stride = plan.strides.back();
|
||||
plan.shape.pop_back();
|
||||
plan.strides.pop_back();
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(
|
||||
x_ptr + offset, out_ptr, reduction_size, reduction_stride, op);
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
strided_reduce(
|
||||
x_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
op);
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == GeneralReduce) {
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
U val = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
val = op(val, *(x_ptr + offset + extra_offset));
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
*out_ptr = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct AndReduce {
|
||||
template <typename T>
|
||||
bool operator()(bool x, T y) {
|
||||
return x & (y != 0);
|
||||
}
|
||||
|
||||
bool operator()(bool x, bool y) {
|
||||
return x & y;
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
|
||||
return x & (y != 0);
|
||||
};
|
||||
|
||||
template <int N>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
|
||||
return x & y;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
bool operator()(simd::Simd<T, N> x) {
|
||||
return simd::all(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct OrReduce {
|
||||
template <typename T>
|
||||
bool operator()(bool x, T y) {
|
||||
return x | (y != 0);
|
||||
}
|
||||
|
||||
bool operator()(bool x, bool y) {
|
||||
return x | y;
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
|
||||
return x | (y != 0);
|
||||
};
|
||||
|
||||
template <int N>
|
||||
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
|
||||
return x | y;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
bool operator()(simd::Simd<T, N> x) {
|
||||
return simd::any(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct MaxReduce {
|
||||
template <typename T>
|
||||
T operator()(T y, T x) {
|
||||
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
|
||||
return simd::maximum(x, y);
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::max(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct MinReduce {
|
||||
template <typename T>
|
||||
T operator()(T y, T x) {
|
||||
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
|
||||
return simd::minimum(x, y);
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::min(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct SumReduce {
|
||||
template <typename T, typename U>
|
||||
U operator()(U y, T x) {
|
||||
return x + y;
|
||||
};
|
||||
|
||||
template <int N, typename T, typename U>
|
||||
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
|
||||
return y + x;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::sum(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ProdReduce {
|
||||
template <typename T, typename U>
|
||||
U operator()(U y, T x) {
|
||||
return x * y;
|
||||
};
|
||||
|
||||
template <int N, typename T, typename U>
|
||||
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
|
||||
return x * y;
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::prod(x);
|
||||
};
|
||||
};
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_and_or(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::And) {
|
||||
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
|
||||
} else {
|
||||
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_sum_prod(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Sum) {
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_min_max(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Max) {
|
||||
auto init = Limits<InT>::min;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
|
||||
} else {
|
||||
auto init = Limits<InT>::max;
|
||||
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
|
||||
}
|
||||
}
|
||||
|
||||
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
switch (reduce_type_) {
|
||||
case Reduce::And:
|
||||
case Reduce::Or: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
case float16:
|
||||
case bfloat16:
|
||||
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
case int32:
|
||||
case float32:
|
||||
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
case int64:
|
||||
case complex64:
|
||||
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Reduce::Sum:
|
||||
case Reduce::Prod: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Reduce::Max:
|
||||
case Reduce::Min: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
@ -2,8 +2,8 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/binary_ops.h"
|
||||
#include "mlx/backend/common/ternary.h"
|
||||
#include "mlx/backend/cpu/binary_ops.h"
|
||||
#include "mlx/backend/cpu/ternary.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
@ -1,9 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/base_simd.h"
|
||||
#include "mlx/backend/cpu/simd/base_simd.h"
|
||||
|
||||
#if MLX_SIMD_LIBRARY_VERSION < 6
|
||||
#include "mlx/backend/common/simd/neon_fp16_simd.h"
|
||||
#include "mlx/backend/cpu/simd/neon_fp16_simd.h"
|
||||
#endif
|
||||
|
||||
namespace mlx::core::simd {
|
@ -7,7 +7,7 @@
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
|
||||
#include "mlx/backend/common/simd/base_simd.h"
|
||||
#include "mlx/backend/cpu/simd/base_simd.h"
|
||||
|
||||
// There seems to be a bug in sims/base.h
|
||||
// __XROS_2_0 is not defined, the expression evaluates
|
||||
@ -299,5 +299,5 @@ T prod(Simd<T, N> x) {
|
||||
} // namespace mlx::core::simd
|
||||
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
#include "mlx/backend/common/simd/accelerate_fp16_simd.h"
|
||||
#include "mlx/backend/cpu/simd/accelerate_fp16_simd.h"
|
||||
#endif
|
@ -2,7 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/type.h"
|
||||
#include "mlx/backend/cpu/simd/type.h"
|
||||
|
||||
namespace mlx::core::simd {
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
#include <arm_neon.h>
|
||||
|
||||
#include "mlx/backend/common/simd/base_simd.h"
|
||||
#include "mlx/backend/cpu/simd/base_simd.h"
|
||||
|
||||
namespace mlx::core::simd {
|
||||
|
4
mlx/backend/cpu/simd/simd.h
Normal file
4
mlx/backend/cpu/simd/simd.h
Normal file
@ -0,0 +1,4 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cpu/simd/math.h"
|
||||
#include "mlx/backend/cpu/simd/type.h"
|
7
mlx/backend/cpu/simd/type.h
Normal file
7
mlx/backend/cpu/simd/type.h
Normal file
@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cpu/simd/base_simd.h"
|
||||
|
||||
#ifdef MLX_USE_ACCELERATE
|
||||
#include "mlx/backend/cpu/simd/accelerate_simd.h"
|
||||
#endif
|
21
mlx/backend/cpu/slicing.h
Normal file
21
mlx/backend/cpu/slicing.h
Normal file
@ -0,0 +1,21 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::tuple<int64_t, Strides> prepare_slice(
|
||||
const array& in,
|
||||
const Shape& start_indices,
|
||||
const Shape& strides);
|
||||
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
size_t data_offset,
|
||||
size_t data_size,
|
||||
array& out);
|
||||
|
||||
} // namespace mlx::core
|
@ -3,8 +3,8 @@
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
@ -5,8 +5,8 @@
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
@ -1,8 +1,8 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
157
mlx/backend/cpu/ternary.h
Normal file
157
mlx/backend/cpu/ternary.h
Normal file
@ -0,0 +1,157 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/ternary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T1, typename T2, typename T3, typename U, typename Op, int D>
|
||||
void ternary_op_dims(
|
||||
const T1* a,
|
||||
const T2* b,
|
||||
const T3* c,
|
||||
U* out,
|
||||
Op op,
|
||||
const Shape& shape,
|
||||
const Strides& a_strides,
|
||||
const Strides& b_strides,
|
||||
const Strides& c_strides,
|
||||
const Strides& out_strides,
|
||||
int axis) {
|
||||
auto stride_a = a_strides[axis];
|
||||
auto stride_b = b_strides[axis];
|
||||
auto stride_c = c_strides[axis];
|
||||
auto stride_out = out_strides[axis];
|
||||
auto N = shape[axis];
|
||||
|
||||
for (int i = 0; i < N; i++) {
|
||||
if constexpr (D > 1) {
|
||||
ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
out,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
c_strides,
|
||||
out_strides,
|
||||
axis + 1);
|
||||
} else {
|
||||
*out = op(*a, *b, *c);
|
||||
}
|
||||
a += stride_a;
|
||||
b += stride_b;
|
||||
c += stride_c;
|
||||
out += stride_out;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, typename T3, typename U, typename Op>
|
||||
void ternary_op_dispatch_dims(
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
Op op) {
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
|
||||
const auto& a_strides = strides[0];
|
||||
const auto& b_strides = strides[1];
|
||||
const auto& c_strides = strides[2];
|
||||
const auto& out_strides = strides[3];
|
||||
|
||||
const T1* a_ptr = a.data<T1>();
|
||||
const T2* b_ptr = b.data<T2>();
|
||||
const T3* c_ptr = c.data<T3>();
|
||||
U* out_ptr = out.data<T3>();
|
||||
int ndim = shape.size();
|
||||
switch (ndim) {
|
||||
case 1:
|
||||
ternary_op_dims<T1, T2, T3, U, Op, 1>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
c_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
case 2:
|
||||
ternary_op_dims<T1, T2, T3, U, Op, 2>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
c_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator b_it(shape, b_strides, ndim - 2);
|
||||
ContiguousIterator c_it(shape, c_strides, ndim - 2);
|
||||
auto stride = out_strides[ndim - 3];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
ternary_op_dims<T1, T2, T3, U, Op, 2>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
c_ptr + c_it.loc,
|
||||
out_ptr + elem,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
c_strides,
|
||||
out_strides,
|
||||
ndim - 2);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
c_it.step();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, typename T3, typename U, typename Op>
|
||||
void ternary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
Op op) {
|
||||
TernaryOpType topt = get_ternary_op_type(a, b, c);
|
||||
set_ternary_op_output_data(a, b, c, out, topt);
|
||||
|
||||
// The full computation is scalar-scalar-scalar so we call the base op once.
|
||||
if (topt == TernaryOpType::ScalarScalarScalar) {
|
||||
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
|
||||
} else if (topt == TernaryOpType::VectorVectorVector) {
|
||||
const T1* a_ptr = a.data<T1>();
|
||||
const T2* b_ptr = b.data<T2>();
|
||||
const T3* c_ptr = c.data<T3>();
|
||||
U* out_ptr = out.data<U>();
|
||||
for (size_t i = 0; i < out.size(); ++i) {
|
||||
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
|
||||
a_ptr++;
|
||||
b_ptr++;
|
||||
c_ptr++;
|
||||
out_ptr++;
|
||||
}
|
||||
} else {
|
||||
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -1,6 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/threefry.h"
|
||||
#include "mlx/backend/cpu/threefry.h"
|
||||
|
||||
namespace mlx::core::random {
|
||||
|
@ -2,8 +2,8 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/common/unary_ops.h"
|
||||
#include "mlx/backend/cpu/unary.h"
|
||||
#include "mlx/backend/cpu/unary_ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
@ -4,14 +4,12 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void set_unary_output_data(const array& in, array& out) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
@ -137,6 +135,4 @@ void unary_fp(const array& a, array& out, Op op) {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
@ -6,7 +6,7 @@
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
|
||||
namespace mlx::core::detail {
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <set>
|
||||
|
||||
#include "mlx/3rdparty/pocketfft.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/binary.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
|
@ -1,10 +1,2 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/reduce_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/utils.cpp)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp)
|
||||
|
@ -1,6 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/compile_impl.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <dlfcn.h>
|
||||
#include <mpi.h>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/distributed/mpi/mpi.h"
|
||||
|
@ -13,7 +13,7 @@
|
||||
|
||||
#include <json.hpp>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/threadpool.h"
|
||||
|
Loading…
Reference in New Issue
Block a user