Refactor common into cpu specific and truly common (#1817)

* refactor

* fix extension example

* fix no-cpu
Awni Hannun
2025-02-03 15:58:02 -08:00
committed by GitHub
parent ec7c7def40
commit 1156c84e86
72 changed files with 1426 additions and 1434 deletions

View File

@@ -1,88 +1,8 @@
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(COMPILER ${CMAKE_C_COMPILER})
set(CLANG TRUE)
else()
set(COMPILER ${CMAKE_CXX_COMPILER})
endif()
set(COMPILE_DEPS
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
simd/simd.h
simd/base_simd.h
simd/math.h
simd/type.h
unary_ops.h
binary_ops.h)
if(MSVC)
set(SHELL_EXT ps1)
set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
else()
set(SHELL_EXT sh)
set(SHELL_CMD bash)
endif()
add_custom_command(
OUTPUT compiled_preamble.cpp
COMMAND
${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
DEPENDS make_compiled_preamble.${SHELL_EXT} compiled_preamble.h
${COMPILE_DEPS})
add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
add_dependencies(mlx cpu_compiled_preamble)
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cblas.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
if(MLX_BUILD_ACCELERATE)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
endif()
if(IOS)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp)
endif()
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
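
For context, the make_compiled_preamble script invoked above appears to preprocess the headers listed in COMPILE_DEPS and embed the result as a string that the CPU compile path later prepends to generated kernels via get_kernel_preamble() (declared in compiled_preamble.h further down). A rough, illustrative sketch of the shape of the generated compiled_preamble.cpp — not the actual script output — might be:

// Illustrative sketch only; the real file is produced by make_compiled_preamble.
const char* get_kernel_preamble() {
  return R"preamble(
// ... preprocessed contents of half_types.h, fp16.h, bf16.h, complex.h,
// simd/*.h, unary_ops.h and binary_ops.h would be embedded here ...
)preamble";
}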

View File

@@ -1,74 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
namespace mlx::core {
namespace {
template <typename T>
void arange(T start, T next, array& out, size_t size) {
auto ptr = out.data<T>();
auto step_size = next - start;
for (int i = 0; i < size; ++i) {
ptr[i] = start;
start += step_size;
}
}
} // namespace
void arange(
const std::vector<array>& inputs,
array& out,
double start,
double step) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (out.dtype()) {
case bool_:
throw std::runtime_error("Bool type unsupported for arange.");
break;
case uint8:
arange<uint8_t>(start, start + step, out, out.size());
break;
case uint16:
arange<uint16_t>(start, start + step, out, out.size());
break;
case uint32:
arange<uint32_t>(start, start + step, out, out.size());
break;
case uint64:
arange<uint64_t>(start, start + step, out, out.size());
break;
case int8:
arange<int8_t>(start, start + step, out, out.size());
break;
case int16:
arange<int16_t>(start, start + step, out, out.size());
break;
case int32:
arange<int32_t>(start, start + step, out, out.size());
break;
case int64:
arange<int64_t>(start, start + step, out, out.size());
break;
case float16:
arange<float16_t>(start, start + step, out, out.size());
break;
case float32:
arange<float>(start, start + step, out, out.size());
break;
case bfloat16:
arange<bfloat16_t>(start, start + step, out, out.size());
break;
case complex64:
arange<complex64_t>(start, start + step, out, out.size());
break;
}
}
} // namespace mlx::core
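
For intuition, the templated fill loop above simply adds the step (next - start) to a running value. A standalone sketch, not part of the diff:

// Standalone sketch of the arange<T> fill loop above.
#include <cstdio>
#include <vector>

int main() {
  float start = 0.0f, next = 0.5f; // step = next - start = 0.5
  float step = next - start;
  std::vector<float> out(4);
  for (size_t i = 0; i < out.size(); ++i) {
    out[i] = start;
    start += step;
  }
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]); // 0 0.5 1 1.5
}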

View File

@@ -1,112 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/primitives.h"
#include "utils.h"
namespace mlx::core {
namespace {
template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis];
Strides strides = in.strides();
Shape shape = in.shape();
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
for (uint32_t i = 0; i < out.size(); ++i) {
auto loc = elem_to_loc(i, shape, strides);
auto in_ptr = in.data<InT>() + loc;
uint32_t ind_v = 0;
InT v = (*in_ptr);
for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) {
op(j, (*in_ptr), &ind_v, &v);
}
out.data<uint32_t>()[i] = ind_v;
}
}
template <typename InT>
void arg_reduce_dispatch(
const array& in,
array& out,
ArgReduce::ReduceType rtype,
int axis) {
switch (rtype) {
case ArgReduce::ArgMin: {
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
if (x < (*y)) {
(*y) = x;
(*ind_y) = ind_x;
}
};
arg_reduce<InT>(in, out, op, axis);
break;
}
case ArgReduce::ArgMax: {
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
if (x > (*y)) {
(*y) = x;
(*ind_y) = ind_x;
}
};
arg_reduce<InT>(in, out, op, axis);
break;
}
}
}
} // namespace
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (in.dtype()) {
case bool_:
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
break;
case uint8:
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
break;
case uint16:
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
break;
case uint32:
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
break;
case uint64:
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
break;
case int8:
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
break;
case int16:
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
break;
case int32:
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
break;
case int64:
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
break;
case float16:
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
break;
case float32:
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
break;
case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
break;
}
}
} // namespace mlx::core
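
For intuition, the ArgMax branch above boils down to the following per-row update, which keeps the first occurrence on ties because the comparison is strict. A standalone sketch, not code from the diff:

// Standalone sketch of the ArgMax update rule used in arg_reduce_dispatch.
#include <cstdint>
#include <cstdio>

int main() {
  float row[5] = {1.0f, 7.0f, 3.0f, 7.0f, 2.0f};
  uint32_t best_idx = 0;
  float best_val = row[0];
  for (uint32_t j = 0; j < 5; ++j) {
    if (row[j] > best_val) { // strict '>' keeps the first index on ties
      best_val = row[j];
      best_idx = j;
    }
  }
  std::printf("argmax = %u\n", best_idx); // prints 1
}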

View File

@@ -1,341 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <cmath>
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/binary_ops.h"
#include "mlx/backend/common/binary_two.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
template <typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool>(a, b, out, op);
break;
case uint8:
binary_op<uint8_t, bool>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t, bool>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t, bool>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t, bool>(a, b, out, op);
break;
case int8:
binary_op<int8_t, bool>(a, b, out, op);
break;
case int16:
binary_op<int16_t, bool>(a, b, out, op);
break;
case int32:
binary_op<int32_t, bool>(a, b, out, op);
break;
case int64:
binary_op<int64_t, bool>(a, b, out, op);
break;
case float16:
binary_op<float16_t, bool>(a, b, out, op);
break;
case float32:
binary_op<float, bool>(a, b, out, op);
break;
case bfloat16:
binary_op<bfloat16_t, bool>(a, b, out, op);
break;
case complex64:
binary_op<complex64_t, bool>(a, b, out, op);
break;
}
}
} // namespace
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Add());
}
void DivMod::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto integral_op = [](auto x, auto y) {
return std::make_pair(x / y, x % y);
};
auto float_op = [](auto x, auto y) {
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
};
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, integral_op);
break;
case uint8:
binary_op<uint8_t>(a, b, outputs, integral_op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, integral_op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, integral_op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, integral_op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, integral_op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, integral_op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, integral_op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, integral_op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, float_op);
break;
case float32:
binary_op<float>(a, b, outputs, float_op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, float_op);
break;
case complex64:
// Should never get here
throw std::runtime_error("[DivMod] Complex type not supported");
break;
}
}
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Divide());
}
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Remainder());
}
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (equal_nan_) {
switch (a.dtype()) {
case float16:
binary_op<float16_t, bool>(a, b, out, detail::NaNEqual());
break;
case float32:
binary_op<float, bool>(a, b, out, detail::NaNEqual());
break;
case bfloat16:
binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
break;
case complex64:
binary_op<complex64_t, bool>(a, b, out, detail::NaNEqual());
break;
default:
throw std::runtime_error(
"[NanEqual::eval_cpu] Only for floating point types.");
}
} else {
comparison_op(a, b, out, detail::Equal());
}
}
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Greater());
}
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
}
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Less());
}
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
}
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
} else {
throw std::invalid_argument(
"[logaddexp] Cannot compute logaddexp for arrays with"
" non floating point type.");
}
}
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd());
}
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr());
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Maximum());
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Minimum());
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Multiply());
}
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
}
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Power());
}
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Subtract());
}
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto dispatch_type = [&a, &b, &out](auto op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
break;
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
default:
throw std::runtime_error(
"[BitwiseBinary::eval_cpu] Type not supported");
break;
}
};
switch (op_) {
case BitwiseBinary::And:
dispatch_type(detail::BitwiseAnd());
break;
case BitwiseBinary::Or:
dispatch_type(detail::BitwiseOr());
break;
case BitwiseBinary::Xor:
dispatch_type(detail::BitwiseXor());
break;
case BitwiseBinary::LeftShift:
dispatch_type(detail::LeftShift());
break;
case BitwiseBinary::RightShift:
dispatch_type(detail::RightShift());
break;
}
}
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[arctan2] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
} else {
throw std::invalid_argument(
"[arctan2] Cannot compute inverse tangent for arrays"
" with non floating point type.");
}
}
} // namespace mlx::core
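
As a quick check of the DivMod lambdas above: the integral path uses C++ truncated division and %, and the float path mirrors it with std::trunc and std::fmod, so both return a (quotient, remainder) pair. A standalone sketch, not from the diff:

// Standalone check of the DivMod quotient/remainder lambdas.
#include <cmath>
#include <cstdio>
#include <utility>

int main() {
  auto integral_op = [](auto x, auto y) { return std::make_pair(x / y, x % y); };
  auto float_op = [](auto x, auto y) {
    return std::make_pair(std::trunc(x / y), std::fmod(x, y));
  };
  auto [q1, r1] = integral_op(7, 2);    // q1 = 3, r1 = 1
  auto [q2, r2] = float_op(7.5f, 2.0f); // q2 = 3.0, r2 = 1.5
  std::printf("%d %d %g %g\n", q1, r1, q2, r2);
}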

View File

@@ -1,18 +1,13 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/common/simd/simd.h"
namespace mlx::core {
namespace {
enum class BinaryOpType {
ScalarScalar,
ScalarVector,
@@ -21,7 +16,7 @@ enum class BinaryOpType {
General,
};
BinaryOpType get_binary_op_type(const array& a, const array& b) {
inline BinaryOpType get_binary_op_type(const array& a, const array& b) {
BinaryOpType bopt;
if (a.data_size() == 1 && b.data_size() == 1) {
bopt = BinaryOpType::ScalarScalar;
@@ -39,7 +34,7 @@ BinaryOpType get_binary_op_type(const array& a, const array& b) {
return bopt;
}
void set_binary_op_output_data(
inline void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
@@ -124,361 +119,4 @@ void set_binary_op_output_data(
}
}
template <typename Op>
struct VectorScalar {
Op op;
VectorScalar(Op op_) : op(op_) {}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *b;
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
dst += N;
a += N;
size -= N;
}
while (size-- > 0) {
*dst = op(*a, scalar);
dst++;
a++;
}
}
};
template <typename Op>
struct ScalarVector {
Op op;
ScalarVector(Op op_) : op(op_) {}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *a;
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, op(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
dst += N;
b += N;
size -= N;
}
while (size-- > 0) {
*dst = op(scalar, *b);
dst++;
b++;
}
}
};
template <typename Op>
struct VectorVector {
Op op;
VectorVector(Op op_) : op(op_) {}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a), simd::load<T, N>(b)));
dst += N;
a += N;
b += N;
size -= N;
}
while (size-- > 0) {
*dst = op(*a, *b);
dst++;
a++;
b++;
}
}
};
template <typename T, typename U, typename Op, int D, bool Strided>
void binary_op_dims(
const T* a,
const T* b,
U* out,
Op op,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
auto stride_out = out_strides[axis];
auto N = shape[axis];
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
binary_op_dims<T, U, Op, D - 1, Strided>(
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
} else {
if constexpr (Strided) {
op(a, b, out, stride_out);
} else {
*out = op(*a, *b);
}
}
out += stride_out;
a += stride_a;
b += stride_b;
}
}
template <typename T, typename U, bool Strided, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out,
Op op,
int dim,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_ptr = out.data<U>();
switch (dim) {
case 1:
binary_op_dims<T, U, Op, 1, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 2:
binary_op_dims<T, U, Op, 2, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 3:
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
}
ContiguousIterator a_it(shape, a_strides, dim - 3);
ContiguousIterator b_it(shape, b_strides, dim - 3);
auto stride = out_strides[dim - 4];
for (int64_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
out_ptr + elem,
op,
shape,
a_strides,
b_strides,
out_strides,
dim - 3);
a_it.step();
b_it.step();
}
}
template <typename T, typename U, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
return;
}
// General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out.strides()});
const auto& a_strides = new_strides[0];
const auto& b_strides = new_strides[1];
const auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
}
return d + 1;
};
auto a_rc_dim = leftmost_rc_dim(a_strides);
auto b_rc_dim = leftmost_rc_dim(b_strides);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == 0; d--) {
}
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a_strides);
auto b_s_dim = leftmost_s_dim(b_strides);
auto ndim = new_shape.size();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
dim = d;
}
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
if (dim == 0 || strides[dim - 1] < 16) {
bopt = BinaryOpType::General;
dim = ndim;
}
switch (bopt) {
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U, true>(
a,
b,
out,
VectorVector{op},
dim,
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U, true>(
a,
b,
out,
VectorScalar{op},
dim,
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U, true>(
a,
b,
out,
ScalarVector{op},
dim,
new_shape,
a_strides,
b_strides,
strides);
break;
default:
binary_op_dispatch_dims<T, U, false>(
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
break;
}
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
binary_op<T, T>(a, b, out, op);
}
template <typename Op>
void binary(const array& a, const array& b, array& out, Op op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
break;
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
case float16:
binary_op<float16_t>(a, b, out, op);
break;
case float32:
binary_op<float>(a, b, out, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, op);
break;
case complex64:
binary_op<complex64_t>(a, b, out, op);
break;
}
}
} // namespace
} // namespace mlx::core
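
The VectorScalar / ScalarVector / VectorVector functors above are SIMD-accelerated versions of simple element-wise loops: the first while loop handles full lanes of simd::max_size<T> elements and the trailing while loop handles the remainder. Stripped of SIMD, the VectorScalar path is equivalent to this scalar-only sketch (illustrative, not from the diff):

// Scalar-only analogue of the VectorScalar fast path (SIMD lanes omitted).
// Usage mirrors VectorScalar{op}(a, b, dst, size) with b a broadcast scalar.
template <typename T, typename U, typename Op>
void vector_scalar_ref(const T* a, const T* b, U* dst, int size, Op op) {
  T scalar = *b; // b points at a single broadcast value
  for (int i = 0; i < size; ++i) {
    dst[i] = op(a[i], scalar);
  }
}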

View File

@@ -1,98 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include "mlx/backend/common/simd/simd.h"
namespace mlx::core::detail {
using namespace mlx::core::simd;
#define BINARY_SINGLE() \
template <typename T> \
T operator()(T x, T y) { \
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value; \
}
#define DEFAULT_BINARY_OP(Op, op) \
struct Op { \
template <int N, typename T> \
Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) { \
return op(x, y); \
} \
BINARY_SINGLE() \
};
DEFAULT_BINARY_OP(Add, operator+)
DEFAULT_BINARY_OP(ArcTan2, atan2)
DEFAULT_BINARY_OP(Divide, operator/)
DEFAULT_BINARY_OP(Multiply, operator*)
DEFAULT_BINARY_OP(Subtract, operator-)
DEFAULT_BINARY_OP(LogicalAnd, operator&&)
DEFAULT_BINARY_OP(LogicalOr, operator||)
DEFAULT_BINARY_OP(BitwiseAnd, operator&)
DEFAULT_BINARY_OP(BitwiseOr, operator|)
DEFAULT_BINARY_OP(BitwiseXor, operator^)
DEFAULT_BINARY_OP(LeftShift, operator<<)
DEFAULT_BINARY_OP(RightShift, operator>>)
DEFAULT_BINARY_OP(Remainder, remainder)
DEFAULT_BINARY_OP(Maximum, maximum)
DEFAULT_BINARY_OP(Minimum, minimum)
DEFAULT_BINARY_OP(Power, pow)
#define DEFAULT_BOOL_OP(Op, op) \
struct Op { \
template <int N, typename T> \
Simd<bool, N> operator()(Simd<T, N> x, Simd<T, N> y) { \
return op(x, y); \
} \
template <typename T> \
bool operator()(T x, T y) { \
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value; \
} \
};
DEFAULT_BOOL_OP(Equal, operator==)
DEFAULT_BOOL_OP(Greater, operator>)
DEFAULT_BOOL_OP(GreaterEqual, operator>=)
DEFAULT_BOOL_OP(Less, operator<)
DEFAULT_BOOL_OP(LessEqual, operator<=)
DEFAULT_BOOL_OP(NotEqual, operator!=)
struct NaNEqual {
template <int N, typename T>
Simd<bool, N> operator()(Simd<T, N> x, Simd<T, N> y) {
return x == y || (isnan(x) && isnan(y));
}
template <typename T>
bool operator()(T x, T y) {
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value;
}
};
struct LogAddExp {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) {
auto maxval = maximum(x, y);
auto minval = minimum(x, y);
auto mask = minval == -inf || maxval == inf;
auto out = maxval + log1p(exp(minval - maxval));
return select(mask, Simd<T, N>(maxval), Simd<T, N>(out));
}
BINARY_SINGLE()
};
struct Select {
template <typename T>
T operator()(bool condition, T x, T y) {
return (*this)(Simd<bool, 1>(condition), Simd<T, 1>(x), Simd<T, 1>(y))
.value;
}
template <int N, typename T>
Simd<T, N> operator()(Simd<bool, N> condition, Simd<T, N> x, Simd<T, N> y) {
return select(condition, x, y);
}
};
} // namespace mlx::core::detail
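
For readability, here is roughly what DEFAULT_BINARY_OP(Add, operator+) expands to once BINARY_SINGLE is substituted (reconstructed from the macros above; Simd comes from the mlx::core::simd namespace pulled in by the using-directive):

struct Add {
  template <int N, typename T>
  Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) {
    return operator+(x, y);
  }
  template <typename T>
  T operator()(T x, T y) {
    return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value;
  }
};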

View File

@@ -1,219 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
namespace {
template <typename T, typename U, typename Op, int D>
void binary_op_dims(
const T* a,
const T* b,
U* out_a,
U* out_b,
Op op,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
auto stride_out = out_strides[axis];
auto N = shape[axis];
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
binary_op_dims<T, U, Op, D - 1>(
a,
b,
out_a,
out_b,
op,
shape,
a_strides,
b_strides,
out_strides,
axis + 1);
} else {
std::tie(*out_a, *out_b) = op(*a, *b);
}
a += stride_a;
b += stride_b;
out_a += stride_out;
out_b += stride_out;
}
}
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out_a.strides()});
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& out_strides = strides[2];
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_a_ptr = out_a.data<U>();
U* out_b_ptr = out_b.data<U>();
int ndim = shape.size();
switch (ndim) {
case 1:
binary_op_dims<T, U, Op, 1>(
a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 2:
binary_op_dims<T, U, Op, 2>(
a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
}
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
out_a_ptr + elem,
out_b_ptr + elem,
op,
shape,
a_strides,
b_strides,
out_strides,
ndim - 2);
a_it.step();
b_it.step();
}
}
template <typename T, typename U = T, typename Op>
void binary_op(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
auto bopt = get_binary_op_type(a, b);
auto& out_a = outputs[0];
auto& out_b = outputs[1];
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
// General (strided/broadcast) case: dispatch over dimensions
if (bopt == BinaryOpType::General) {
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
return;
}
auto a_ptr = a.data<T>();
auto b_ptr = b.data<T>();
auto out_a_ptr = out_a.data<U>();
auto out_b_ptr = out_b.data<U>();
if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
} else if (bopt == BinaryOpType::ScalarVector) {
for (size_t i = 0; i < b.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
b_ptr++;
}
} else if (bopt == BinaryOpType::VectorScalar) {
for (size_t i = 0; i < a.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
}
} else { // VectorVector
for (size_t i = 0; i < a.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
b_ptr++;
}
}
}
template <typename Op>
void binary(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, op);
break;
case uint8:
binary_op<uint8_t>(a, b, outputs, op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, op);
break;
case float32:
binary_op<float>(a, b, outputs, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, op);
break;
case complex64:
binary_op<complex64_t>(a, b, outputs, op);
break;
}
}
} // namespace
} // namespace mlx::core

View File

@@ -1,74 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
void cholesky_impl(const array& a, array& factor, bool upper) {
// Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric:
// (A)ᵀ = A
// and that a column-major lower triangular matrix is a row-major upper
// triangular matrix, so uplo is the opposite of what we would expect from
// upper
char uplo = (upper) ? 'L' : 'U';
// The decomposition is computed in place, so just copy the input to the
// output.
copy(
a,
factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
float* matrix = factor.data<float>();
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info;
MLX_LAPACK_FUNC(spotrf)
(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how
// to catch errors from the implementation we should throw.
if (info < 0) {
std::stringstream msg;
msg << "[cholesky] Cholesky decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
// Zero out the upper/lower triangle while advancing the pointer to the
// next matrix at the same time.
for (int row = 0; row < N; row++) {
if (upper) {
std::fill(matrix, matrix + row, 0);
} else {
std::fill(matrix + row + 1, matrix + N, 0);
}
matrix += N;
}
}
}
void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Cholesky::eval] only supports float32.");
}
cholesky_impl(inputs[0], output, upper_);
}
} // namespace mlx::core
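
To make the uplo swap above concrete (an illustration, not code from the diff): the row-major buffer of a 2x2 lower-triangular factor is {l11, 0, l21, l22}; LAPACK reads that same buffer column-major, i.e. as the transposed, upper-triangular matrix. So asking LAPACK for its lower factor ('L') yields, once the buffer is reinterpreted row-major, the upper factor the caller requested, and vice versa:

// Row-major view of the buffer     // LAPACK's column-major view of it
// [ l11   0  ]                     // [ l11  l21 ]
// [ l21  l22 ]                     // [  0   l22 ]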

View File

@@ -1,373 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include <filesystem>
#include <fstream>
#include <list>
#include <mutex>
#include <shared_mutex>
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h"
#include "mlx/backend/common/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
namespace mlx::core {
struct CompilerCache {
struct DLib {
DLib(const std::string& libname) {
lib = dlopen(libname.c_str(), RTLD_NOW);
if (!lib) {
std::ostringstream msg;
msg << "Could not load C++ shared library " << dlerror();
throw std::runtime_error(msg.str());
}
}
~DLib() {
dlclose(lib);
}
void* lib;
};
// Statics to cache compiled libraries and functions
std::list<DLib> libs;
std::unordered_map<std::string, void*> kernels;
std::shared_mutex mtx;
};
static CompilerCache cache{};
// GPU compile is always available if the GPU is available. Since we are in
// this file, CPU compile is also available.
namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::function<std::string(void)>& source_builder) {
{
std::shared_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
}
std::unique_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
std::string source_code = source_builder();
std::string kernel_file_name;
// Deal with long kernel names. Maximum length for filename on macOS is 255
// characters, and on Windows the maximum length for whole path is 260. Clip
// file name with a little extra room and append a 16 character hash.
#ifdef _WIN32
constexpr int max_file_name_length = 140;
#else
constexpr int max_file_name_length = 245;
#endif
if (kernel_name.size() > max_file_name_length) {
std::ostringstream file_name;
file_name
<< std::string_view(kernel_name).substr(0, max_file_name_length - 16);
auto file_id =
std::hash<std::string>{}(kernel_name.substr(max_file_name_length - 16));
file_name << "_" << std::hex << std::setw(16) << file_id << std::dec;
kernel_file_name = file_name.str();
} else {
kernel_file_name = kernel_name;
}
auto output_dir = std::filesystem::temp_directory_path();
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
auto shared_lib_path = (output_dir / shared_lib_name).string();
bool lib_exists = false;
{
std::ifstream f(shared_lib_path.c_str());
lib_exists = f.good();
}
if (!lib_exists) {
// Open source file and write source code to it
std::string source_file_name = kernel_file_name + ".cpp";
auto source_file_path = (output_dir / source_file_name).string();
std::ofstream source_file(source_file_path);
source_file << source_code;
source_file.close();
try {
JitCompiler::exec(JitCompiler::build_command(
output_dir, source_file_name, shared_lib_name));
} catch (const std::exception& error) {
throw std::runtime_error(fmt::format(
"[Compile::eval_cpu] Failed to compile function {0}: {1}",
kernel_name,
error.what()));
}
}
// load library
cache.libs.emplace_back(shared_lib_path);
// Load function
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
if (!fun) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function "
<< kernel_name << std::endl
<< dlerror();
throw std::runtime_error(msg.str());
}
cache.kernels.insert({kernel_name, fun});
return fun;
}
inline void build_kernel(
std::ostream& os,
const std::string& kernel_name,
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids,
bool contiguous,
int ndim) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
// Constants are scalars that are captured by value and cannot change
auto is_constant = [&constant_ids](const array& x) {
return constant_ids.find(x.id()) != constant_ids.end();
};
NodeNamer namer;
#ifdef _MSC_VER
// Export the symbol
os << "__declspec(dllexport) ";
#endif
// Start the kernel
os << "void " << kernel_name << "(void** args) {" << std::endl;
// Add the input arguments
int cnt = 0;
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
// Skip constants from the input list
if (is_constant(x)) {
continue;
}
auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
<< "];" << std::endl;
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< "];" << std::endl;
}
}
// Add the output arguments
for (auto& x : outputs) {
auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl;
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
}
if (contiguous) {
os << " for (size_t i = 0; i < size; ++i) {" << std::endl;
} else {
for (int d = 0; d < ndim; ++d) {
os << " for (int i" << d << " = 0; i" << d << " < shape[" << d
<< "]; ++i" << d << ") {" << std::endl;
}
}
// Read the inputs in tmps
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
if (is_constant(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
print_constant(os, x);
os << ";" << std::endl;
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
} else if (contiguous) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[i];" << std::endl;
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = *"
<< xname << ";" << std::endl;
}
}
// Actually write the computation
for (auto& x : tape) {
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = ";
if (is_static_cast(x.primitive())) {
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else {
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
}
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
if (contiguous) {
os << " " << namer.get_name(x) << "[i] = tmp_" << namer.get_name(x)
<< ";" << std::endl;
} else {
os << " *" << namer.get_name(x) << "++ = tmp_" << namer.get_name(x)
<< ";" << std::endl;
}
}
// Close loops
if (contiguous) {
os << " }" << std::endl;
} else {
for (int d = ndim - 1; d >= 0; --d) {
// Update pointers
for (auto& x : inputs) {
if (is_constant(x) || is_scalar(x)) {
continue;
}
auto& xname = namer.get_name(x);
os << " " << xname << " += " << xname << "_strides[" << d << "];"
<< std::endl;
if (d < ndim - 1) {
os << " " << xname << " -= " << xname << "_strides[" << d + 1 << "]"
<< " * shape[" << d + 1 << "];" << std::endl;
}
}
os << " }" << std::endl;
}
}
// Finish the kernel
os << "}" << std::endl;
}
void Compiled::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
if (kernel_lib_.empty()) {
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
}
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, shape);
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;
std::vector<std::vector<size_t>> strides;
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
}
auto& x = inputs[i];
args.push_back((void*)x.data<void>());
if (contiguous || is_scalar(x)) {
continue;
}
// Broadcast the input to the output shape.
std::vector<size_t> xstrides;
int j = 0;
for (; j < shape.size() - x.ndim(); j++) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); i++, j++) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides.push_back(std::move(xstrides));
args.push_back(strides.back().data());
}
// Get the kernel name from the lib
int ndim = shape.size();
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
if (!contiguous) {
kernel_name += std::to_string(shape.size());
}
// Get the function
auto fn_ptr = compile(kernel_name, [&]() {
std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl;
build_kernel(
kernel,
kernel_name,
inputs_,
outputs_,
tape_,
constant_ids_,
contiguous,
ndim);
// Close extern "C"
kernel << "}" << std::endl;
return kernel.str();
});
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false);
for (auto& x : outputs) {
args.push_back(x.data<void>());
}
if (!contiguous) {
args.push_back((void*)outputs[0].shape().data());
} else {
args.push_back((void*)outputs[0].data_size());
}
auto fun = (void (*)(void**))fn_ptr;
fun(args.data());
}
} // namespace mlx::core
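
To see what build_kernel above emits, here is an illustrative sketch of the generated source for a contiguous float32 add of two inputs. The variable names are made up (the real ones come from NodeNamer), and Add is assumed to be brought into scope by the preamble's using-directives:

extern "C" {
void add_float32_contiguous(void** args) {
  float* in0 = (float*)args[0];
  float* in1 = (float*)args[1];
  float* out0 = (float*)args[2];
  const size_t size = (size_t)args[3];
  for (size_t i = 0; i < size; ++i) {
    float tmp_in0 = in0[i];
    float tmp_in1 = in1[i];
    float tmp_out0 = Add()(tmp_in0, tmp_in1);
    out0[i] = tmp_out0;
  }
}
}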

View File

@@ -1,23 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
namespace mlx::core {
// GPU compile is always available if the GPU is available. Since we are in
// this file, CPU compile is not available, so only report support for GPU
// devices.
namespace detail {
bool compile_available_for_device(const Device& device) {
return device == Device::gpu;
}
} // namespace detail
void Compiled::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error(
"[Compiled::eval_cpu] CPU compialtion not supported on the platform.");
}
} // namespace mlx::core

View File

@@ -1,12 +0,0 @@
// Copyright © 2023-24 Apple Inc.
#pragma once
// clang-format off
#include "mlx/types/half_types.h"
#include "mlx/types/complex.h"
#include "mlx/backend/common/unary_ops.h"
#include "mlx/backend/common/binary_ops.h"
// clang-format on
const char* get_kernel_preamble();

File diff suppressed because it is too large

View File

@@ -1,315 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <numeric>
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/simd/simd.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
namespace {
template <typename SrcT, typename DstT>
void copy_single(const array& src, array& dst) {
auto val = static_cast<DstT>(src.data<SrcT>()[0]);
auto dst_ptr = dst.data<DstT>();
for (int i = 0; i < dst.size(); ++i) {
dst_ptr[i] = val;
}
}
template <typename SrcT, typename DstT>
void copy_vector(const array& src, array& dst) {
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
size_t size = src.data_size();
std::copy(src_ptr, src_ptr + size, dst_ptr);
}
template <typename SrcT, typename DstT, int D>
inline void copy_dims(
const SrcT* src,
DstT* dst,
const Shape& shape,
const Strides& i_strides,
const Strides& o_strides,
int axis) {
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = shape[axis];
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
copy_dims<SrcT, DstT, D - 1>(
src, dst, shape, i_strides, o_strides, axis + 1);
} else {
*dst = static_cast<DstT>(*src);
}
src += stride_src;
dst += stride_dst;
}
}
template <typename SrcT, typename DstT>
void copy_general_general(
const array& src,
array& dst,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset) {
if (data_shape.empty()) {
auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset));
auto dst_ptr = dst.data<DstT>() + o_offset;
*dst_ptr = val;
return;
}
auto [shape, strides] =
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>() + o_offset;
int ndim = shape.size();
if (ndim == 1) {
copy_dims<SrcT, DstT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 2) {
copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 3) {
copy_dims<SrcT, DstT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
}
ContiguousIterator in(shape, strides[0], ndim - 3);
ContiguousIterator out(shape, strides[1], ndim - 3);
auto stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
for (int64_t elem = 0; elem < src.size(); elem += stride) {
copy_dims<SrcT, DstT, 3>(
src_ptr + in.loc,
dst_ptr + out.loc,
shape,
strides[0],
strides[1],
ndim - 3);
in.step();
out.step();
}
}
template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
}
template <typename SrcT, typename DstT>
void copy_general(
const array& src,
array& dst,
const Shape& data_shape,
const Strides& i_strides,
const Strides&,
int64_t i_offset,
int64_t o_offset) {
copy_general_general<SrcT, DstT>(
src,
dst,
data_shape,
i_strides,
make_contiguous_strides(data_shape),
i_offset,
o_offset);
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT>(
src,
dst,
src.shape(),
src.strides(),
make_contiguous_strides(src.shape()),
0,
0);
}
template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (ctype) {
case CopyType::Scalar:
copy_single<SrcT, DstT>(src, dst);
return;
case CopyType::Vector:
copy_vector<SrcT, DstT>(src, dst);
return;
case CopyType::General:
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
return;
case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
return;
}
}
template <typename SrcT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (dst.dtype()) {
case bool_:
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint8:
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint16:
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint32:
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint64:
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int8:
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int16:
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int32:
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int64:
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float16:
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float32:
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case complex64:
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
}
}
template <typename... Args>
inline void copy_inplace_dispatch(
const array& src,
array& dst,
CopyType ctype,
Args&&... args) {
switch (src.dtype()) {
case bool_:
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint8:
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint16:
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint32:
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint64:
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int8:
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int16:
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int32:
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int64:
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float16:
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float32:
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case complex64:
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
}
}
} // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype) {
copy_inplace_dispatch(src, dst, ctype);
}
void copy(const array& src, array& dst, CopyType ctype) {
// Allocate the output
switch (ctype) {
case CopyType::Vector:
if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
dst.copy_shared_buffer(src);
} else {
auto size = src.data_size();
dst.set_data(
allocator::malloc_or_wait(size * dst.itemsize()),
size,
src.strides(),
src.flags());
}
break;
case CopyType::Scalar:
case CopyType::General:
case CopyType::GeneralGeneral:
dst.set_data(allocator::malloc_or_wait(dst.nbytes()));
break;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_inplace(src, dst, ctype);
}
void copy_inplace(
const array& src,
array& dst,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
break;
case CopyType::Scalar:
case CopyType::Vector:
copy_inplace_dispatch(src, dst, ctype);
}
}
} // namespace mlx::core

View File

@@ -3,7 +3,6 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
@@ -23,17 +22,4 @@ enum class CopyType {
GeneralGeneral
};
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
void copy_inplace(
const array& src,
array& dst,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
} // namespace mlx::core

View File

@@ -1,119 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
void ssyevd(
char jobz,
char uplo,
float* a,
int N,
float* w,
float* work,
int lwork,
int* iwork,
int liwork) {
int info;
MLX_LAPACK_FUNC(ssyevd)
(
/* jobz = */ &jobz,
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ a,
/* lda = */ &N,
/* w = */ w,
/* work = */ work,
/* lwork = */ &lwork,
/* iwork = */ iwork,
/* liwork = */ &liwork,
/* info = */ &info);
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
} // namespace
void Eigh::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];
auto vectors = compute_eigenvectors_
? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc_or_wait(values.nbytes()));
copy(
a,
vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
// are in the columns of the output
auto flags = vectors.flags();
auto strides = vectors.strides();
auto ndim = a.ndim();
std::swap(strides[ndim - 1], strides[ndim - 2]);
if (a.size() > 1) {
flags.row_contiguous = false;
if (ndim > 2) {
flags.col_contiguous = false;
} else {
flags.col_contiguous = true;
}
}
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
}
auto vec_ptr = vectors.data<float>();
auto eig_ptr = values.data<float>();
char jobz = compute_eigenvectors_ ? 'V' : 'N';
auto N = a.shape(-1);
// Workspace query: calling ssyevd with lwork = -1 and liwork = -1 makes
// LAPACK report the optimal work/iwork sizes instead of computing.
int lwork;
int liwork;
{
float work;
int iwork;
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
lwork = static_cast<int>(work);
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
for (size_t i = 0; i < a.size() / (N * N); ++i) {
ssyevd(
jobz,
uplo_[0],
vec_ptr,
N,
eig_ptr,
static_cast<float*>(work_buf.buffer.raw_ptr()),
lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
liwork);
vec_ptr += N * N;
eig_ptr += N;
}
}
} // namespace mlx::core

View File

@@ -1,40 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cmath>
namespace mlx::core {
/* Approximation to the inverse error function.
* Based on code from:
* https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
*/
float erfinv(float a) {
auto t = std::fma(a, 0.0f - a, 1.0f);
t = std::log(t);
float p;
if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
p = 3.03697567e-10f; // 0x1.4deb44p-32
p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
} else { // maximum ulp error = 2.35002
p = 5.43877832e-9f; // 0x1.75c000p-28
p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
}
return a * p;
}
} // namespace mlx::core
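
A quick way to sanity-check the polynomial above is a round trip through the standard error function; a standalone sketch (assumes the erfinv definition above is linked in):

// Round-trip check: erf(erfinv(y)) should be close to y.
#include <cmath>
#include <cstdio>

float erfinv(float a); // as defined above

int main() {
  float y = 0.5f;
  std::printf("erf(erfinv(0.5)) = %f\n", std::erf(erfinv(y))); // ~0.500000
}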

View File

@@ -1,87 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <numeric>
#include "mlx/3rdparty/pocketfft.h"
#include "mlx/allocator.h"
#include "mlx/primitives.h"
namespace mlx::core {
void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
std::vector<std::ptrdiff_t> strides_in(
in.strides().begin(), in.strides().end());
for (auto& s : strides_in) {
s *= in.itemsize();
}
std::vector<std::ptrdiff_t> strides_out(
out.strides().begin(), out.strides().end());
for (auto& s : strides_out) {
s *= out.itemsize();
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<size_t> shape;
if (out.dtype() == float32) {
shape.insert(shape.end(), out.shape().begin(), out.shape().end());
} else {
shape.insert(shape.end(), in.shape().begin(), in.shape().end());
}
float scale = 1.0f;
if (inverse_) {
size_t nelem = std::accumulate(
axes_.begin(), axes_.end(), 1, [&shape](auto x, auto y) {
return x * shape[y];
});
scale /= nelem;
}
if (in.dtype() == complex64 && out.dtype() == complex64) {
auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::c2c(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
} else if (in.dtype() == float32 && out.dtype() == complex64) {
auto in_ptr = in.data<float>();
auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::r2c(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
} else if (in.dtype() == complex64 && out.dtype() == float32) {
auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr = out.data<float>();
pocketfft::c2r(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
} else {
throw std::runtime_error(
"[FFT] Received unexpected input and output type combination.");
}
}
} // namespace mlx::core
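
The inverse transform above is normalized by dividing by the number of elements along the transformed axes, which is exactly what the std::accumulate call computes. A small standalone sketch of that scale computation, with a made-up shape and axes:

// Standalone sketch: inverse-FFT scale = 1 / (product of sizes along the transformed axes).
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<size_t> shape = {4, 8, 16}; // example shape
  std::vector<int> axes = {1, 2};         // example transformed axes
  size_t nelem = std::accumulate(
      axes.begin(), axes.end(), size_t(1), [&shape](size_t acc, int ax) {
        return acc * shape[ax];
      });
  std::printf("scale = %g\n", 1.0f / nelem); // 1 / 128
  return 0;
}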

View File

@@ -1,20 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
template <typename T>
void matmul(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta);
} // namespace mlx::core

View File

@@ -1,157 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <Accelerate/Accelerate.h>
#include "mlx/array.h"
#include "mlx/backend/common/gemm.h"
#include "mlx/backend/common/utils.h"
#include "mlx/dtype.h"
namespace mlx::core {
BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
uint32_t size_bits = size_of(mlx_dtype) * 8;
switch (kindof(mlx_dtype)) {
case Dtype::Kind::b:
return BNNSDataTypeBoolean;
case Dtype::Kind::u:
return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
case Dtype::Kind::i:
return BNNSDataType(BNNSDataTypeIntBit | size_bits);
case Dtype::Kind::f:
return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
case Dtype::Kind::V:
return BNNSDataTypeBFloat16;
case Dtype::Kind::c:
throw std::invalid_argument("BNNS does not support complex types");
}
}
void matmul_bnns(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha,
/* float beta = */ beta,
/* bool transA = */ a_transposed,
/* bool transB = */ b_transposed,
/* bool quadratic = */ false,
/* bool a_is_weights = */ false,
/* bool b_is_weights = */ false,
/* BNNSNDArrayDescriptor iA_desc = */
BNNSNDArrayDescriptor{
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
{lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
{1, lda, 0, 0, 0, 0, 0, 0},
/* void * _Nullable data = */ nullptr,
/* BNNSDataType data_type = */ bnns_dtype,
/* void * _Nullable table_data = */ nullptr,
/* BNNSDataType table_data_type = */ bnns_dtype,
/* float data_scale = */ 1.0,
/* float data_bias = */ 0.0,
},
/* BNNSNDArrayDescriptor iB_desc = */
BNNSNDArrayDescriptor{
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
{ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
{1, ldb, 0, 0, 0, 0, 0, 0},
/* void * _Nullable data = */ nullptr,
/* BNNSDataType data_type = */ bnns_dtype,
/* void * _Nullable table_data = */ nullptr,
/* BNNSDataType table_data_type = */ bnns_dtype,
/* float data_scale = */ 1.0,
/* float data_bias = */ 0.0,
},
/* BNNSNDArrayDescriptor o_desc = */
BNNSNDArrayDescriptor{
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
{N, M, 0, 0, 0, 0, 0, 0},
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
{1, N, 0, 0, 0, 0, 0, 0},
/* void * _Nullable data = */ nullptr,
/* BNNSDataType data_type = */ bnns_dtype,
/* void * _Nullable table_data = */ nullptr,
/* BNNSDataType table_data_type = */ bnns_dtype,
/* float data_scale = */ 1.0,
/* float data_bias = */ 0.0,
},
};
auto bnns_filter =
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
BNNSFilterApplyTwoInput(
bnns_filter,
a.data<uint8_t>() +
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
b.data<uint8_t>() +
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
out.data<uint8_t>() + M * N * i * out.itemsize());
}
BNNSFilterDestroy(bnns_filter);
}
template <>
void matmul<float16_t>(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta) {
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
}
template <>
void matmul<bfloat16_t>(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta) {
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
}
} // namespace mlx::core

View File

@@ -1,44 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/gemm.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
template <>
void matmul<float>(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
alpha, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
beta, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
}
}
} // namespace mlx::core
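
For reference, a minimal standalone cblas_sgemm call in the same row-major, untransposed configuration as the loop above, assuming a CBLAS header is available; the shapes and values are made up for illustration.

// Standalone sketch: C (2x2) = A (2x3) * B (3x2), row-major, no transposes.
#include <cblas.h>
#include <cstdio>

int main() {
  const int M = 2, N = 2, K = 3;
  float A[M * K] = {1, 2, 3, 4, 5, 6};
  float B[K * N] = {1, 0, 0, 1, 1, 1};
  float C[M * N] = {};
  cblas_sgemm(
      CblasRowMajor,
      CblasNoTrans, // transA
      CblasNoTrans, // transB
      M,
      N,
      K,
      1.0f, // alpha
      A,
      K, // lda = columns of A when not transposed
      B,
      N, // ldb = columns of B when not transposed
      0.0f, // beta
      C,
      N); // ldc = columns of C
  std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]); // 4 5 / 10 11
  return 0;
}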

View File

@@ -1,21 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const array&,
const array&,
array&,
bool,
bool,
size_t,
size_t,
float,
float) {
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}
} // namespace mlx::core

View File

@@ -1,21 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const array&,
const array&,
array&,
bool,
bool,
size_t,
size_t,
float,
float) {
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
}
} // namespace mlx::core

View File

@@ -1,107 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/primitives.h"
namespace mlx::core {
// n = 2^k component
template <typename T>
void hadamard_n(array& out, int n, int m, float scale) {
for (int b = 0; b < out.size() / n; b++) {
size_t loc = b * n;
T* data_ptr = out.data<T>() + loc;
int h = 1;
int n_over_2 = n / 2;
while (h < n) {
for (int i = 0; i < n / 2; i++) {
int k = i & (h - 1);
int j = ((i - k) << 1) + k;
float x = *(data_ptr + j);
float y = *(data_ptr + j + h);
*(data_ptr + j) = x + y;
*(data_ptr + j + h) = x - y;
if (h == n_over_2) {
*(data_ptr + j) *= scale;
*(data_ptr + j + h) *= scale;
}
}
h <<= 1;
}
}
}
// m component
template <typename T>
void hadamard_m(array& out, int n, int m, float scale) {
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
auto start = 1;
auto end = matrix.find('\n', start);
std::vector<bool> hmat_vec;
while (end != std::string_view::npos) {
auto row = matrix.substr(start, end - start);
for (int i = 0; i < row.length(); i++) {
hmat_vec.push_back(row[i] == '+');
}
start = end + 1;
end = matrix.find('\n', start);
}
for (int b = 0; b < out.size() / m / n; b++) {
size_t loc = b * n * m;
T* data_ptr = out.data<T>() + loc;
for (int i = 0; i < n; i++) {
std::vector<float> out(m);
for (int j = 0; j < m; j++) {
for (int k = 0; k < m; k++) {
float x = *(data_ptr + i + k * n);
if (hmat_vec[k + j * m]) {
out[j] += x;
} else {
out[j] -= x;
}
}
}
for (int j = 0; j < m; j++) {
*(data_ptr + i + j * n) = out[j] * scale;
}
}
}
}
template <typename T>
void hadamard(array& out, int n, int m, float scale) {
float n_scale = m > 1 ? 1.0 : scale;
hadamard_n<T>(out, n, m, n_scale);
if (m > 1) {
hadamard_m<T>(out, n, m, scale);
}
}
void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
// Copy input to output
copy(in, out, CopyType::General);
int axis = out.ndim() - 1;
auto [n, m] = decompose_hadamard(out.shape(axis));
switch (in.dtype()) {
case float32:
return hadamard<float>(out, n, m, scale_);
case float16:
return hadamard<float16_t>(out, n, m, scale_);
case bfloat16:
return hadamard<bfloat16_t>(out, n, m, scale_);
default:
throw std::invalid_argument("[hadamard] Unsupported type.");
}
}
} // namespace mlx::core
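
hadamard_n above is an in-place fast Walsh-Hadamard transform: log2(n) butterfly passes over the buffer, with the scale folded into the final pass. A minimal standalone sketch of the same butterfly (without scaling) on a length-4 vector:

// Standalone sketch: in-place fast Walsh-Hadamard transform for n a power of 2.
#include <cstdio>

void fwht(float* data, int n) {
  for (int h = 1; h < n; h <<= 1) {
    for (int i = 0; i < n; i += 2 * h) {
      for (int j = i; j < i + h; ++j) {
        float x = data[j];
        float y = data[j + h];
        data[j] = x + y;     // butterfly sum
        data[j + h] = x - y; // butterfly difference
      }
    }
  }
}

int main() {
  float v[4] = {1, 0, 1, 0};
  fwht(v, 4);
  std::printf("%g %g %g %g\n", v[0], v[1], v[2], v[3]); // 2 2 0 0
  return 0;
}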

View File

@@ -1,674 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
#include "mlx/allocator.h"
#include "mlx/primitives.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
return (idx < 0) ? idx + size : idx;
}
template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
}
template <typename T, typename IdxT>
void gather(
const array& src,
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const Shape& slice_sizes) {
// If the array is row contiguous then we can do a contiguous copy given
// two conditions on the slice size:
// - Any number of leading ones in the slice sizes are allowed
// - All other slice sizes match the corresponding dimension except the
// first non-singleton slice size
// If the array is col contiguous then the reverse is the case:
// - Any number of trailing ones in the slice sizes are allowed
// - All other slice sizes match the corresponding dimension except the
// first non-singleton slice size from the end
bool can_copy = false;
if (src.flags().row_contiguous) {
can_copy = true;
// Ignore leading 1s
int i = 0;
for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
;
// Check the remaining
i++;
for (; i < src.ndim() && can_copy; ++i) {
can_copy = (src.shape(i) == slice_sizes[i]);
}
} else if (src.flags().col_contiguous) {
can_copy = true;
// Ignore trailing 1s
int i = slice_sizes.size() - 1;
for (; i >= 0 && slice_sizes[i] == 1; --i)
;
// Skip the next slice size and check the remaining
i--;
for (; i >= 0 && can_copy; --i) {
can_copy = (src.shape(i) == slice_sizes[i]);
}
}
size_t slice_size = 1;
for (auto s : slice_sizes) {
slice_size *= s;
}
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();
size_t out_idx = 0;
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator src_it;
if (!can_copy && src.ndim() > 0) {
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
}
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) {
auto ax = axes[ii];
auto idx_loc = its[ii].loc;
its[ii].step();
auto idx_val =
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
src_idx += (idx_val * src.strides()[ax]);
}
if (slice_size == 1) {
dst_ptr[out_idx++] = src_ptr[src_idx];
} else if (can_copy) {
std::copy(
src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
out_idx += slice_size;
} else {
for (int jj = 0; jj < slice_size; jj++) {
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
src_it.step();
}
src_it.reset();
}
}
}
template <typename IdxT>
void dispatch_gather(
const array& src,
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const Shape& size) {
switch (out.dtype()) {
case bool_:
gather<bool, IdxT>(src, inds, out, axes, size);
break;
case uint8:
gather<uint8_t, IdxT>(src, inds, out, axes, size);
break;
case uint16:
gather<uint16_t, IdxT>(src, inds, out, axes, size);
break;
case uint32:
gather<uint32_t, IdxT>(src, inds, out, axes, size);
break;
case uint64:
gather<uint64_t, IdxT>(src, inds, out, axes, size);
break;
case int8:
gather<int8_t, IdxT>(src, inds, out, axes, size);
break;
case int16:
gather<int16_t, IdxT>(src, inds, out, axes, size);
break;
case int32:
gather<int32_t, IdxT>(src, inds, out, axes, size);
break;
case int64:
gather<int64_t, IdxT>(src, inds, out, axes, size);
break;
case float16:
gather<float16_t, IdxT>(src, inds, out, axes, size);
break;
case float32:
gather<float, IdxT>(src, inds, out, axes, size);
break;
case bfloat16:
gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
break;
case complex64:
gather<complex64_t, IdxT>(src, inds, out, axes, size);
break;
}
}
void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end());
if (inds.empty()) {
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
return;
}
switch (inds[0].dtype()) {
case uint8:
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint16:
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint32:
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint64:
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
break;
case int8:
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
break;
case int16:
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
break;
case int32:
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
break;
case int64:
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
break;
default:
throw std::runtime_error(
"[Gather::eval_cpu] Cannot gather with indices type.");
break;
}
}
template <typename T, typename IdxT>
void gather_axis(
const array& src,
const array& ind,
array& out,
const int axis) {
auto strides = ind.strides();
strides.erase(strides.begin() + axis);
auto shape = ind.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
strides = src.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
auto ind_ptr = ind.data<IdxT>();
auto src_ptr = src.data<T>();
auto dst_ptr = out.data<T>();
auto ind_ax_stride = ind.strides(axis);
auto src_ax_stride = src.strides(axis);
auto dst_ax_stride = out.strides(axis);
auto ind_ax_size = ind.shape(axis);
auto src_ax_size = src.shape(axis);
size_t size_pre = 1;
size_t size_post = 1;
for (int i = 0; i < axis; ++i) {
size_pre *= ind.shape(i);
}
for (int i = axis + 1; i < ind.ndim(); ++i) {
size_post *= ind.shape(i);
}
size_t stride_pre = size_post * ind_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < ind_ax_size; ++j) {
auto ind_val = offset_neg_idx(
ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size);
dst_ptr[k + j * dst_ax_stride] =
src_ptr[src_it.loc + ind_val * src_ax_stride];
}
ind_it.step();
src_it.step();
}
dst_ptr += stride_pre;
}
}
template <typename IdxT>
void dispatch_gather_axis(
const array& src,
const array& inds,
array& out,
const int axis) {
switch (out.dtype()) {
case bool_:
gather_axis<bool, IdxT>(src, inds, out, axis);
break;
case uint8:
gather_axis<uint8_t, IdxT>(src, inds, out, axis);
break;
case uint16:
gather_axis<uint16_t, IdxT>(src, inds, out, axis);
break;
case uint32:
gather_axis<uint32_t, IdxT>(src, inds, out, axis);
break;
case uint64:
gather_axis<uint64_t, IdxT>(src, inds, out, axis);
break;
case int8:
gather_axis<int8_t, IdxT>(src, inds, out, axis);
break;
case int16:
gather_axis<int16_t, IdxT>(src, inds, out, axis);
break;
case int32:
gather_axis<int32_t, IdxT>(src, inds, out, axis);
break;
case int64:
gather_axis<int64_t, IdxT>(src, inds, out, axis);
break;
case float16:
gather_axis<float16_t, IdxT>(src, inds, out, axis);
break;
case float32:
gather_axis<float, IdxT>(src, inds, out, axis);
break;
case bfloat16:
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
break;
case complex64:
gather_axis<complex64_t, IdxT>(src, inds, out, axis);
break;
}
}
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& src = inputs[0];
auto& inds = inputs[1];
switch (inds.dtype()) {
case uint8:
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
break;
case uint16:
dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
break;
case uint32:
dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
break;
case uint64:
dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
break;
case int8:
dispatch_gather_axis<int8_t>(src, inds, out, axis_);
break;
case int16:
dispatch_gather_axis<int16_t>(src, inds, out, axis_);
break;
case int32:
dispatch_gather_axis<int32_t>(src, inds, out, axis_);
break;
case int64:
dispatch_gather_axis<int64_t>(src, inds, out, axis_);
break;
default:
throw std::runtime_error(
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
break;
}
}
template <typename InT, typename IdxT, typename OpT>
void scatter(
const array& updates,
array& out,
const std::vector<array>& inds,
const std::vector<int>& axes,
const OpT& op) {
int nind = inds.size();
auto inds_ndim = updates.ndim() - out.ndim();
size_t n_updates = nind ? inds[0].size() : 1;
Shape update_shape(
updates.shape().begin() + inds_ndim, updates.shape().end());
size_t update_size = 1;
for (auto us : update_shape) {
update_size *= us;
}
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator update_it(updates);
ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;
for (int j = 0; j < nind; ++j) {
auto ax = axes[j];
auto idx_loc = its[j].loc;
its[j].step();
auto idx_val =
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
out_offset += (idx_val * out.strides()[ax]);
}
update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) {
op(updates.data<InT>()[update_it.loc],
out.data<InT>() + out_offset + out_it.loc);
update_it.step();
out_it.step();
}
out_it.reset();
update_it.reset();
}
}
template <typename InT, typename IdxT>
void dispatch_scatter_inds(
array& out,
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
Scatter::ReduceType rtype) {
switch (rtype) {
case Scatter::None:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) = x; });
break;
case Scatter::Sum:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) += x; });
break;
case Scatter::Prod:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) *= x; });
break;
case Scatter::Max:
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
(*y) = (*y > x) ? *y : x;
});
break;
case Scatter::Min:
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
(*y) = (*y < x) ? *y : x;
});
break;
}
}
template <typename InT>
void dispatch_scatter(
array& out,
const std::vector<array>& inds,
const array& updates,
const std::vector<int>& axes,
Scatter::ReduceType rtype) {
if (inds.empty()) {
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
return;
}
switch (inds[0].dtype()) {
case uint8:
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
break;
case uint16:
dispatch_scatter_inds<InT, uint16_t>(out, inds, updates, axes, rtype);
break;
case uint32:
dispatch_scatter_inds<InT, uint32_t>(out, inds, updates, axes, rtype);
break;
case uint64:
dispatch_scatter_inds<InT, uint64_t>(out, inds, updates, axes, rtype);
break;
case int8:
dispatch_scatter_inds<InT, int8_t>(out, inds, updates, axes, rtype);
break;
case int16:
dispatch_scatter_inds<InT, int16_t>(out, inds, updates, axes, rtype);
break;
case int32:
dispatch_scatter_inds<InT, int32_t>(out, inds, updates, axes, rtype);
break;
case int64:
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
break;
default:
throw std::runtime_error(
"[Scatter::eval_cpu] Cannot scatter with indices type.");
}
}
void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() >= 2);
auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end() - 1);
auto& updates = inputs.back();
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype);
switch (src.dtype()) {
case bool_:
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
break;
case uint8:
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint16:
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint32:
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint64:
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
break;
case int8:
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
break;
case int16:
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
break;
case int32:
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
break;
case int64:
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
break;
case float16:
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
break;
case float32:
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
break;
case bfloat16:
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
break;
case complex64:
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
break;
}
}
template <typename T, typename IdxT, typename OpT>
void scatter_axis(
array& out,
const array idx,
const array& upd,
int axis,
const OpT& op) {
auto strides = idx.strides();
strides.erase(strides.begin() + axis);
auto shape = idx.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
strides = upd.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
auto idx_ptr = idx.data<IdxT>();
auto upd_ptr = upd.data<T>();
auto dst_ptr = out.data<T>();
auto idx_ax_stride = idx.strides(axis);
auto upd_ax_stride = upd.strides(axis);
auto dst_ax_stride = out.strides(axis);
auto idx_ax_size = idx.shape(axis);
auto dst_ax_size = out.shape(axis);
size_t size_pre = 1;
size_t size_post = 1;
for (int i = 0; i < axis; ++i) {
size_pre *= idx.shape(i);
}
for (int i = axis + 1; i < idx.ndim(); ++i) {
size_post *= idx.shape(i);
}
size_t stride_pre = size_post * dst_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < idx_ax_size; ++j) {
auto ind_val = offset_neg_idx(
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
op(upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride);
}
idx_it.step();
upd_it.step();
}
dst_ptr += stride_pre;
}
}
template <typename InT, typename IdxT>
void dispatch_scatter_axis_op(
array& out,
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype) {
switch (rtype) {
case ScatterAxis::None:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; });
break;
case ScatterAxis::Sum:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; });
break;
}
}
template <typename InT>
void dispatch_scatter_axis(
array& out,
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype) {
switch (idx.dtype()) {
case uint8:
dispatch_scatter_axis_op<InT, uint8_t>(out, idx, updates, axis, rtype);
break;
case uint16:
dispatch_scatter_axis_op<InT, uint16_t>(out, idx, updates, axis, rtype);
break;
case uint32:
dispatch_scatter_axis_op<InT, uint32_t>(out, idx, updates, axis, rtype);
break;
case uint64:
dispatch_scatter_axis_op<InT, uint64_t>(out, idx, updates, axis, rtype);
break;
case int8:
dispatch_scatter_axis_op<InT, int8_t>(out, idx, updates, axis, rtype);
break;
case int16:
dispatch_scatter_axis_op<InT, int16_t>(out, idx, updates, axis, rtype);
break;
case int32:
dispatch_scatter_axis_op<InT, int32_t>(out, idx, updates, axis, rtype);
break;
case int64:
dispatch_scatter_axis_op<InT, int64_t>(out, idx, updates, axis, rtype);
break;
default:
throw std::runtime_error(
"[ScatterAxis::eval_cpu] Cannot scatter with indices type.");
}
}
void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() >= 2);
auto& src = inputs[0];
auto& idx = inputs[1];
auto& updates = inputs[2];
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype);
switch (src.dtype()) {
case bool_:
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
break;
case uint8:
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint16:
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint32:
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint64:
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
break;
case int8:
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
break;
case int16:
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
break;
case int32:
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
break;
case int64:
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
break;
case float16:
dispatch_scatter_axis<float16_t>(out, idx, updates, axis_, reduce_type_);
break;
case float32:
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
break;
case bfloat16:
dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
break;
case complex64:
dispatch_scatter_axis<complex64_t>(
out, idx, updates, axis_, reduce_type_);
break;
}
}
} // namespace mlx::core
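
Every index that flows through the gather and scatter kernels above is first normalized by offset_neg_idx so negative indices wrap Python-style (the uint32 specialization skips the check, since an unsigned index cannot be negative). A tiny standalone sketch of that normalization:

// Standalone sketch of the negative-index wrap used by offset_neg_idx.
#include <cstddef>
#include <cstdio>

template <typename IdxT>
size_t wrap_index(IdxT idx, size_t size) {
  return (idx < 0) ? idx + size : idx;
}

int main() {
  std::printf("%zu\n", wrap_index(-1, size_t(5))); // 4 (last element)
  std::printf("%zu\n", wrap_index(3, size_t(5)));  // 3 (unchanged)
  return 0;
}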

View File

@@ -1,120 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
int info;
MLX_LAPACK_FUNC(strtri)
(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
return info;
}
namespace mlx::core {
void general_inv(array& inv, int N, int i) {
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
// Compute LU factorization.
sgetrf_(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void tri_inv(array& inv, int N, int i, bool upper) {
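  // LAPACK is column-major, so it sees the row-major input as its transpose;
  // transposing swaps the triangles, which is why uplo is flipped here.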
const char uplo = upper ? 'L' : 'U';
const char diag = 'N';
int info = strtri_wrapper(uplo, diag, inv.data<float>() + N * N * i, N);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: triangular inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
// Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see
// https://math.stackexchange.com/a/340234):
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
// The inverse is computed in place, so just copy the input to the output.
copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
for (int i = 0; i < num_matrices; i++) {
if (tri) {
tri_inv(inv, N, i, upper);
} else {
general_inv(inv, N, i);
}
}
}
void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Inverse::eval] only supports float32.");
}
inverse_impl(inputs[0], output, tri_, upper_);
}
} // namespace mlx::core
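
The identity cited in inverse_impl follows in one line: with B = A⁻¹,

  AB = I  ⇒  (AB)ᵀ = BᵀAᵀ = I  ⇒  Bᵀ = (Aᵀ)⁻¹,  i.e. (A⁻¹)ᵀ = (Aᵀ)⁻¹.

So the row-major buffer can be handed to column-major LAPACK unchanged: LAPACK sees Aᵀ, inverts it, and the bytes it writes back are read by MLX as the row-major A⁻¹.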

View File

@@ -1,152 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/jit_compiler.h"
#include <sstream>
#include <vector>
#include <fmt/format.h>
namespace mlx::core {
#ifdef _MSC_VER
namespace {
// Split string into array.
std::vector<std::string> str_split(const std::string& str, char delimiter) {
std::vector<std::string> tokens;
std::string token;
std::istringstream tokenStream(str);
while (std::getline(tokenStream, token, delimiter)) {
tokens.push_back(token);
}
return tokens;
}
// Get path information about MSVC.
struct VisualStudioInfo {
VisualStudioInfo() {
#ifdef _M_ARM64
arch = "arm64";
#else
arch = "x64";
#endif
// Get path of Visual Studio.
std::string vs_path = JitCompiler::exec(fmt::format(
"\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
" -property installationPath",
std::getenv("ProgramFiles(x86)")));
if (vs_path.empty()) {
throw std::runtime_error("Can not find Visual Studio.");
}
// Read the envs from vcvarsall.
std::string envs = JitCompiler::exec(fmt::format(
"\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
vs_path,
arch));
for (const std::string& line : str_split(envs, '\n')) {
// Each line is in the format "ENV_NAME=values".
auto pos = line.find_first_of('=');
if (pos == std::string::npos || pos == 0 || pos == line.size() - 1)
continue;
std::string name = line.substr(0, pos);
std::string value = line.substr(pos + 1);
if (name == "LIB") {
libpaths = str_split(value, ';');
} else if (name == "VCToolsInstallDir") {
cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
}
}
}
std::string arch;
std::string cl_exe;
std::vector<std::string> libpaths;
};
const VisualStudioInfo& GetVisualStudioInfo() {
static VisualStudioInfo info;
return info;
}
} // namespace
#endif // _MSC_VER
std::string JitCompiler::build_command(
const std::filesystem::path& dir,
const std::string& source_file_name,
const std::string& shared_lib_name) {
#ifdef _MSC_VER
const VisualStudioInfo& info = GetVisualStudioInfo();
std::string libpaths;
for (const std::string& lib : info.libpaths) {
libpaths += fmt::format(" /libpath:\"{0}\"", lib);
}
return fmt::format(
"\""
"cd /D \"{0}\" && "
"\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" "
"/link /out:\"{3}\" {4} 2>&1"
"\"",
dir.string(),
info.cl_exe,
source_file_name,
shared_lib_name,
libpaths);
#else
return fmt::format(
"g++ -std=c++17 -O3 -Wall -fPIC -shared \"{0}\" -o \"{1}\" 2>&1",
(dir / source_file_name).string(),
(dir / shared_lib_name).string());
#endif
}
std::string JitCompiler::exec(const std::string& cmd) {
#ifdef _MSC_VER
FILE* pipe = _popen(cmd.c_str(), "r");
#else
FILE* pipe = popen(cmd.c_str(), "r");
#endif
if (!pipe) {
throw std::runtime_error("popen() failed.");
}
char buffer[128];
std::string ret;
while (fgets(buffer, sizeof(buffer), pipe)) {
ret += buffer;
}
  // Trim trailing whitespace.
ret.erase(
std::find_if(
ret.rbegin(),
ret.rend(),
[](unsigned char ch) { return !std::isspace(ch); })
.base(),
ret.end());
#ifdef _MSC_VER
int status = _pclose(pipe);
#else
int status = pclose(pipe);
#endif
if (status == -1) {
throw std::runtime_error("pclose() failed.");
}
#if defined(_WIN32) || defined(__FreeBSD__)
int code = status;
#else
int code = WEXITSTATUS(status);
#endif
if (code != 0) {
throw std::runtime_error(fmt::format(
"Failed to execute command with return code {0}: \"{1}\", "
"the output is: {2}",
code,
cmd,
ret));
}
return ret;
}
} // namespace mlx::core
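
A hedged sketch of how the two static methods compose, assuming the translation unit is linked against mlx; the directory and file names here are hypothetical.

// Standalone sketch (hypothetical paths): compile an already-written source file
// into a shared library and surface the compiler output on failure.
#include <filesystem>
#include <iostream>
#include <stdexcept>
#include <string>
#include "mlx/backend/common/jit_compiler.h"

int main() {
  std::filesystem::path dir = "/tmp/mlx_jit"; // hypothetical scratch directory
  std::string cmd = mlx::core::JitCompiler::build_command(
      dir, "kernel.cpp", "libkernel.so"); // hypothetical file names
  try {
    // exec throws if the command exits with a non-zero status.
    std::string output = mlx::core::JitCompiler::exec(cmd);
    std::cout << output << std::endl;
  } catch (const std::runtime_error& e) {
    std::cerr << e.what() << std::endl;
    return 1;
  }
  return 0;
}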

View File

@@ -1,20 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <filesystem>
namespace mlx::core {
class JitCompiler {
public:
// Build a shell command that compiles a source code file to a shared library.
static std::string build_command(
const std::filesystem::path& dir,
const std::string& source_file_name,
const std::string& shared_lib_name);
// Run a command and get its output.
static std::string exec(const std::string& cmd);
};
} // namespace mlx::core

View File

@@ -1,33 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
// Required for Visual Studio.
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
#ifdef _MSC_VER
#include <complex>
#define LAPACK_COMPLEX_CUSTOM
#define lapack_complex_float std::complex<float>
#define lapack_complex_double std::complex<double>
#endif
#ifdef MLX_USE_ACCELERATE
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#include <lapack.h>
#endif
#if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME)
// This works around a change in the function signatures of LAPACK >= 3.9.1,
// where functions taking char* also take a strlen argument; see a similar
// change in OpenCV:
// https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57
#define MLX_LAPACK_FUNC(f) LAPACK_##f
#else
#define MLX_LAPACK_FUNC(f) f##_
#endif

View File

@@ -1,38 +0,0 @@
# This script generates a C++ function that provides the CPU
# code for use with kernel generation.
#
# Copyright © 2024 Apple Inc.
$OUTPUT_FILE = $args[0]
$CL = $args[1]
$SRCDIR = $args[2]
# Get command result as array.
$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/compiled_preamble.h"
# Remove empty lines.
# Otherwise there would be too many empty lines, making the result unreadable.
$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }
# Concatenate to string.
$CONTENT = $CONTENT -join "`n"
# Append extra content.
$CONTENT = @"
$($CONTENT)
using namespace mlx::core;
using namespace mlx::core::detail;
"@
# Convert each char to ASCII code.
# Unlike the unix script, which outputs a string literal directly, the output from
# MSVC is far too large to be embedded as a string (compilation would fail), so
# we store it as a static array instead.
$CHARCODES = ([System.Text.Encoding]::ASCII.GetBytes($CONTENT) -join ', ') + ', 0'
$OUTPUT = @"
const char* get_kernel_preamble() {
static char preamble[] = { $CHARCODES };
return preamble;
}
"@
Set-Content -Path $OUTPUT_FILE -Value $OUTPUT

View File

@@ -1,38 +0,0 @@
#!/bin/bash
#
# This script generates a C++ function that provides the CPU
# code for use with kernel generation.
#
# Copyright © 2023-24 Apple Inc.
OUTPUT_FILE=$1
GCC=$2
SRCDIR=$3
CLANG=$4
ARCH=$5
if [ "$CLANG" = "TRUE" ]; then
read -r -d '' INCLUDES <<- EOM
#include <cmath>
#include <complex>
#include <cstdint>
#include <vector>
EOM
CC_FLAGS="-arch ${ARCH}"
else
CC_FLAGS="-std=c++17"
fi
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() {
return R"preamble(
$INCLUDES
$CONTENT
using namespace mlx::core;
using namespace mlx::core::detail;
)preamble";
}
EOF

View File

@@ -1,300 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T, typename mask_t>
inline void mask_matrix(
T* data,
const mask_t* mask,
int block_size,
const int X,
const int Y,
const int64_t X_data_str,
const int64_t Y_data_str,
const int64_t X_mask_str,
const int64_t Y_mask_str,
const size_t mask_offset) {
int tX = (X + block_size - 1) / block_size;
int tY = (Y + block_size - 1) / block_size;
for (int i = 0; i < tX; i++) {
for (int j = 0; j < tY; j++) {
mask_t do_mask = mask[mask_offset + i * X_mask_str + j * Y_mask_str];
if (do_mask != 1) {
int loc_x = i * block_size;
int loc_y = j * block_size;
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
int size_x = std::min(block_size, X - loc_x);
int size_y = std::min(block_size, Y - loc_y);
for (int ii = 0; ii < size_x; ii++) {
for (int jj = 0; jj < size_y; jj++) {
if constexpr (std::is_same_v<mask_t, bool>) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
} else {
data_block[ii * X_data_str + jj * Y_data_str] *= do_mask;
}
}
}
}
}
}
}
} // namespace
void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[BlockMaskedMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto check_transpose =
[](const array& arr, bool do_copy, bool expand_all = false) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(false, stx, arr_copy);
}
return std::make_tuple(false, stx, arr);
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(true, sty, arr_copy);
}
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
bool has_op_mask = inputs.size() > 3;
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
auto [a_transposed, lda, a] =
check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);
auto [b_transposed, ldb, b] =
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
auto mask_array = [](const array& mask,
float* data,
int block_size,
int batch_idx,
int X,
int Y,
size_t X_data_str,
size_t Y_data_str) {
auto mask_offset = elem_to_loc(
mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
auto X_mask_str = mask.strides()[mask.ndim() - 2];
auto Y_mask_str = mask.strides()[mask.ndim() - 1];
if (mask.dtype() == bool_) {
return mask_matrix(
data,
mask.data<bool>(),
block_size,
X,
Y,
X_data_str,
Y_data_str,
X_mask_str,
Y_mask_str,
mask_offset);
} else {
return mask_matrix(
data,
mask.data<float>(),
block_size,
X,
Y,
X_data_str,
Y_data_str,
X_mask_str,
Y_mask_str,
mask_offset);
}
};
for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) {
    // Adjust pointers for this batch element
float* ai =
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
float* bi =
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides());
float* ci = out.data<float>() + M * N * i;
// Zero out blocks in a and b if needed
if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2];
mask_array(
a_mask,
ai,
block_size_,
i,
M,
K,
a_transposed ? 1 : lda,
a_transposed ? lda : 1);
auto& b_mask = inputs[inputs.size() - 1];
mask_array(
b_mask,
bi,
block_size_,
i,
K,
N,
b_transposed ? 1 : ldb,
b_transposed ? ldb : 1);
}
// Do matmul
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0, // alpha
ai,
lda,
bi,
ldb,
0.0, // beta
ci,
out.shape(-1) // ldc
);
// Zero out blocks in out
if (has_out_mask) {
mask_array(inputs[2], ci, block_size_, i, M, N, N, 1);
}
}
}
void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[GatherMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto check_transpose = [](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
// Get batch dims
auto batch_size_out = out.size() / (M * N);
size_t matrix_stride_out = M * N;
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};
};
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
auto batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
auto batch_shape_A = get_batch_dims(a.shape());
auto batch_strides_A = get_batch_dims(a.strides());
auto batch_shape_B = get_batch_dims(b.shape());
auto batch_strides_B = get_batch_dims(b.strides());
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < batch_size_out; i++) {
// Get index
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)];
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)];
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a.data<float>() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b.data<float>() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out.data<float>() + matrix_stride_out * i,
out.shape(-1) // ldc
);
}
}
} // namespace mlx::core
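
mask_matrix above walks the operand in block_size x block_size tiles and zeroes (or, for float masks, scales) each tile whose mask entry is off, clamping tiles at the boundary. A minimal standalone sketch of the boolean case on a made-up 4x4 matrix with a 2x2 tile grid:

// Standalone sketch: zero the block_size x block_size tiles of a row-major
// matrix wherever the corresponding boolean mask entry is false.
#include <cstdio>

int main() {
  const int X = 4, Y = 4, block_size = 2;
  float data[X * Y];
  for (int i = 0; i < X * Y; ++i) {
    data[i] = 1.0f;
  }
  bool mask[4] = {true, false, false, true}; // keep diagonal tiles only
  for (int i = 0; i < X / block_size; ++i) {
    for (int j = 0; j < Y / block_size; ++j) {
      if (!mask[i * (Y / block_size) + j]) {
        for (int ii = 0; ii < block_size; ++ii) {
          for (int jj = 0; jj < block_size; ++jj) {
            data[(i * block_size + ii) * Y + (j * block_size + jj)] = 0.0f;
          }
        }
      }
    }
  }
  for (int r = 0; r < X; ++r) {
    for (int c = 0; c < Y; ++c) {
      std::printf("%g ", data[r * Y + c]);
    }
    std::printf("\n");
  }
  return 0;
}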

View File

@@ -1,79 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/gemm.h"
#include "mlx/primitives.h"
namespace mlx::core {
void matmul_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
auto check_transpose = [](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (out.dtype() == float32) {
matmul<float>(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
} else if (out.dtype() == float16) {
matmul<float16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
} else if (out.dtype() == bfloat16) {
matmul<bfloat16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
} else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
}
}
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
if (inputs[0].shape(-1) == 0) {
std::memset(out.data<void>(), 0, out.nbytes());
return;
}
return matmul_general(inputs[0], inputs[1], out);
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
// Fill output with C
auto& c = inputs[2];
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(c, out, ctype);
return matmul_general(inputs[0], inputs[1], out, alpha_, beta_);
}
} // namespace mlx::core
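
check_transpose above derives the BLAS transpose flag and leading dimension from the strides of the last two axes: a row-contiguous matrix is passed untransposed with lda equal to its column count, a column-contiguous one is passed transposed with lda equal to its row count, and anything else is copied to a row-contiguous buffer first. A small standalone sketch of that classification on made-up strides:

// Standalone sketch: classify a 2D layout from its strides (in elements).
#include <cstdio>

int main() {
  long shape[2] = {4, 3};   // 4 x 3 matrix
  long strides[2] = {1, 4}; // column-major example
  bool row_major = (strides[0] == shape[1] && strides[1] == 1);
  bool col_major = (strides[0] == 1 && strides[1] == shape[0]);
  bool needs_copy = !row_major && !col_major;
  std::printf(
      "row_major=%d col_major=%d needs_copy=%d\n",
      row_major,
      col_major,
      needs_copy);
  return 0;
}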

View File

@@ -1,411 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/backend/common/arange.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/load.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/threefry.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
int64_t compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes) {
auto compute_offset = [&strides, &axes](const auto* indices) {
int64_t offset = 0;
for (int i = 0; i < axes.size(); ++i) {
offset += indices[i] * strides[axes[i]];
}
return offset;
};
switch (indices.dtype()) {
case int8:
case uint8:
return compute_offset(indices.data<uint8_t>());
case int16:
case uint16:
return compute_offset(indices.data<uint16_t>());
case int32:
case uint32:
return compute_offset(indices.data<uint32_t>());
case int64:
case uint64:
return compute_offset(indices.data<uint64_t>());
default:
throw std::runtime_error("Invalid indices type.");
}
}
void AsStrided::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Broadcast::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void BroadcastAxes::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Copy::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void CustomTransforms::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Depends::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void ExpandDims::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void NumberOfElements::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Slice::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Split::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Squeeze::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void StopGradient::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
arange(inputs, out, start_, step_);
}
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
}
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
std::vector<int> sizes;
sizes.push_back(0);
for (auto& p : inputs) {
sizes.push_back(p.shape(axis_));
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral);
}
}
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General);
}
}
void Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
assert(in.dtype() == out.dtype());
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy(in, out, ctype);
}
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
load(out, offset_, reader_, swap_endianness_);
}
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
  // Inputs must be the base input array and the scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val
copy(val, out, CopyType::Scalar);
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes_.size(); i++) {
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
data_offset += out.strides()[ax] * low_pad_size_[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_inplace(in, out_slice, CopyType::GeneralGeneral);
}
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
size_t num_keys = keys.size() / 2;
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto kptr = inputs[0].data<uint32_t>();
auto cptr = out.data<char>();
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
auto ptr = reinterpret_cast<uint32_t*>(cptr);
// Get ith key
auto kidx = 2 * i;
auto k1_elem = elem_to_loc(kidx, keys.shape(), keys.strides());
auto k2_elem = elem_to_loc(kidx + 1, keys.shape(), keys.strides());
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
for (; count.first + 1 < half_size; count.first++, count.second++) {
std::tie(ptr[count.first], ptr[count.second]) =
random::threefry2x32_hash(key, count);
}
if (count.first < half_size) {
auto rb = random::threefry2x32_hash(key, count);
ptr[count.first++] = rb.first;
if (bytes_per_key % 4 > 0) {
std::copy(
reinterpret_cast<char*>(&rb.second),
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
cptr + 4 * count.second);
} else {
ptr[count.second] = rb.second;
}
}
if (!even) {
count.second = 0;
ptr[half_size] = random::threefry2x32_hash(key, count).first;
}
}
}
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Slice::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
  // Calculate the output strides and the initial offset into the input
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
size_t data_size = data_end - data_offset;
Strides ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto i_offset = compute_dynamic_offset(inputs[1], in.strides(), axes_);
copy_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
/* const Strides& i_strides = */ in.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ i_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void DynamicSliceUpdate::eval_cpu(
const std::vector<array>& inputs,
array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
// Copy or move src to dst
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
auto o_offset = compute_dynamic_offset(inputs[2], out.strides(), axes_);
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
/* const std::vector<stride_t>& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ o_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
if (upd.size() == 0) {
out.copy_shared_buffer(in);
return;
}
// Check if materialization is needed
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
  // Calculate the output strides and the offset of the updated region
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
// Do copy
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
  // Conditions under which the buffer can be shared in place (any one suffices):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);
copy_inplace(in_tmp, tmp, CopyType::General);
} else {
copy_inplace(in, tmp, CopyType::General);
}
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core
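
When View::eval_cpu can alias the buffer, it keeps the last stride and rescales every other stride by the ratio of the old to the new item size. A small arithmetic sketch of that rescaling for a hypothetical contiguous (2, 3) float32 array viewed as int16:

// Standalone sketch: rescale element strides when reinterpreting a contiguous
// (2, 3) float32 buffer (4-byte items) as int16 (2-byte items).
#include <cstdio>
#include <vector>

int main() {
  const long ibytes = 4, obytes = 2;  // float32 -> int16
  std::vector<long> strides = {3, 1}; // row-major strides of the (2, 3) input
  for (size_t i = 0; i + 1 < strides.size(); ++i) {
    strides[i] = strides[i] * ibytes / obytes;
  }
  // The viewed shape is (2, 6): the last axis doubles, its stride stays 1,
  // and the leading stride becomes 6.
  std::printf("out strides: %ld %ld\n", strides[0], strides[1]); // 6 1
  return 0;
}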

View File

@@ -1,161 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {
template <typename T>
struct lpack;
template <>
struct lpack<float> {
static void xgeqrf(
const int* m,
const int* n,
float* a,
const int* lda,
float* tau,
float* work,
const int* lwork,
int* info) {
sgeqrf_(m, n, a, lda, tau, work, lwork, info);
}
static void xorgqr(
const int* m,
const int* n,
const int* k,
float* a,
const int* lda,
const float* tau,
float* work,
const int* lwork,
int* info) {
sorgqr_(m, n, k, a, lda, tau, work, lwork, info);
}
};
template <typename T>
void qrf_impl(const array& a, array& q, array& r) {
const int M = a.shape(-2);
const int N = a.shape(-1);
const int lda = M;
size_t num_matrices = a.size() / (M * N);
int num_reflectors = std::min(M, N);
auto tau =
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
// Copy A into an in-place working buffer and make it col-contiguous
array in(a.shape(), float32, nullptr, {});
auto flags = in.flags();
// Set up flags and strides for a column-contiguous layout
flags.col_contiguous = num_matrices == 1;
flags.row_contiguous = false;
auto strides = in.strides();
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral);
T optimal_work;
int lwork = -1;
int info;
// Compute workspace size
lpack<T>::xgeqrf(
&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
// Update workspace size
lwork = optimal_work;
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Solve
lpack<T>::xgeqrf(
&M,
&N,
in.data<float>() + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
allocator::free(work);
r.set_data(allocator::malloc_or_wait(r.nbytes()));
for (int i = 0; i < num_matrices; ++i) {
// R is num_reflectors x N, upper triangular
for (int j = 0; j < r.shape(-2); ++j) {
for (int k = 0; k < j; ++k) {
r.data<T>()[i * N * num_reflectors + j * N + k] = 0;
}
for (int k = j; k < r.shape(-1); ++k) {
r.data<T>()[i * N * num_reflectors + j * N + k] =
in.data<T>()[i * N * M + j + k * M];
}
}
}
// Get work size
lwork = -1;
lpack<T>::xorgqr(
&M,
&num_reflectors,
&num_reflectors,
nullptr,
&lda,
nullptr,
&optimal_work,
&lwork,
&info);
lwork = optimal_work;
work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Compute Q
lpack<T>::xorgqr(
&M,
&num_reflectors,
&num_reflectors,
in.data<float>() + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
q.set_data(allocator::malloc_or_wait(q.nbytes()));
for (int i = 0; i < num_matrices; ++i) {
// Q is M x num_reflectors
for (int j = 0; j < q.shape(-2); ++j) {
for (int k = 0; k < q.shape(-1); ++k) {
q.data<T>()[i * M * num_reflectors + j * num_reflectors + k] =
in.data<T>()[i * N * M + j + k * M];
}
}
}
// Cleanup
allocator::free(work);
allocator::free(tau);
}
void QRF::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[QRF::eval] only supports float32.");
}
qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
}
} // namespace mlx::core

View File

@@ -1,642 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/simd/simd.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
template <typename T, int bits>
void extract_bits(const uint8_t* w_in, T* w_out) {
assert(bits == 3 || bits == 6);
if (bits == 3) {
w_out[0] = static_cast<T>(w_in[0] & 0x7);
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
w_out[2] = static_cast<T>(((w_in[0] & 0xc0) >> 6) + ((w_in[1] & 0x1) << 2));
w_out[3] = static_cast<T>((w_in[1] & 0xe) >> 1);
w_out[4] = static_cast<T>((w_in[1] & 0x70) >> 4);
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
} else if (bits == 6) {
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
w_out[1] =
static_cast<T>(((w_in[0] >> 6) & 0x03) + ((w_in[1] & 0x0f) << 2));
w_out[2] =
static_cast<T>(((w_in[1] >> 4) & 0x0f) + ((w_in[2] & 0x03) << 4));
w_out[3] = static_cast<T>((w_in[2] >> 2) & 0x3f);
}
}
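// A worked example of the 3-bit layout above (illustrative only): the packed
// bytes {0xD1, 0x58, 0x1F} hold the eight values {1, 2, 3, 4, 5, 6, 7, 0},
// packed little-endian within each byte:
//   uint8_t packed[3] = {0xD1, 0x58, 0x1F};
//   float vals[8];
//   extract_bits<float, 3>(packed, vals); // vals == {1, 2, 3, 4, 5, 6, 7, 0}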
template <typename T, int bits, int group_size>
void _qmm(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const T* scales_local = scales;
const T* biases_local = biases;
std::fill(result, result + N, 0);
for (int k = 0; k < K; k++) {
T* result_local = result;
T xi = *x++;
for (int n = 0; n < N; n += group_size) {
T scale = *scales_local++;
T bias = *biases_local++;
for (int ng = 0; ng < packs_in_group; ng++) {
if (bits == 3 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) += xi * (scale * wl[p] + bias);
}
w_local += bytes_per_pack;
} else {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * (scale * static_cast<T>(wi & bitmask) + bias);
if (bits != 8) {
wi >>= bits;
}
}
}
}
}
}
result += N;
}
}
template <typename T, int bits, int group_size>
void _qmm_t(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint8_t* w_local = (const uint8_t*)w;
const T* scales_local = scales;
const T* biases_local = biases;
for (int n = 0; n < N; n++) {
const T* x_local = x;
T sum = 0;
for (int k = 0; k < K; k += group_size) {
T scale = *scales_local++;
T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw++) {
if (bits == 3 || bits == 6) {
T wl[pack_factor];
extract_bits<T, bits>(w_local, wl);
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum += x_local[p] * (scale * wl[p] + bias);
}
w_local += bytes_per_pack;
x_local += pack_factor;
} else {
uint8_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum +=
(*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
if (bits != 8) {
wi >>= bits;
}
}
}
}
}
*result = sum;
result++;
}
x += K;
}
}
template <int bits, int S>
simd::Simd<uint32_t, S> extract_bits_simd(const uint32_t* w) {
constexpr int bitmask = (1 << bits) - 1;
simd::Simd<uint32_t, S> wi;
if constexpr (bits == 4 && S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
wi = simd::Simd<uint32_t, S>(*w);
wi = wi >> shifts;
wi = wi & bitmask;
} else if constexpr (bits == 8 && S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 8, 16, 24, 0, 8, 16, 24}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
auto l = simd::Simd<uint32_t, 4>(*w++);
auto r = simd::Simd<uint32_t, 4>(*w);
wi = simd::Simd<uint32_t, S>(l, r);
wi = wi >> shifts;
wi = wi & bitmask;
} else {
// Appease the compiler; this path should never be reached.
throw std::runtime_error("Unsupported combination for simd qmm.");
}
return wi;
}
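// Illustrative example of the shift-and-mask trick above: for bits == 4 and
// S == 8, the word w = 0x87654321 is broadcast to all lanes, shifted by
// {0, 4, ..., 28} and masked with 0xF, yielding lanes {1, 2, 3, 4, 5, 6, 7, 8}.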
template <typename T, int bits, int group_size>
void _qmm_t_simd(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K) {
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
constexpr int S = simd::max_size<T>;
static_assert(
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
constexpr int packs_per_simd = S / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const T* scales_local = scales;
const T* biases_local = biases;
for (int n = 0; n < N; n++) {
simd::Simd<float, S> acc(0);
auto x_local = x;
for (int k = 0; k < K; k += group_size) {
T scale = *scales_local++;
T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
auto wf = simd::Simd<float, S>(extract_bits_simd<bits, S>(w_local));
w_local += packs_per_simd;
wf = wf * scale;
wf = wf + bias;
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
acc = acc + x_simd * wf;
x_local += S;
}
}
*result = T(simd::sum(acc));
result++;
}
x += K;
}
}
template <typename T, int bits, int group_size>
void _qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
// the simd size must be a multiple of the number of elements per word
if constexpr (32 % bits == 0 && simd::max_size<T> % (32 / bits) == 0) {
_qmm_t_simd<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
} else {
_qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
}
} else {
_qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
}
}
template <typename T, int bits>
void _qmm_dispatch_group(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
int group_size,
bool transposed_w) {
switch (group_size) {
case 32:
_qmm_dispatch_transpose<T, bits, 32>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 64:
_qmm_dispatch_transpose<T, bits, 64>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 128:
_qmm_dispatch_transpose<T, bits, 128>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
default:
throw std::invalid_argument(
"Quantization group size must be 32, 64 or 128.");
}
}
template <typename T>
void _qmm_dispatch_typed(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
int group_size,
int bits,
bool transposed_w) {
switch (bits) {
case 2:
_qmm_dispatch_group<T, 2>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 3:
_qmm_dispatch_group<T, 3>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 4:
_qmm_dispatch_group<T, 4>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 6:
_qmm_dispatch_group<T, 6>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 8:
_qmm_dispatch_group<T, 8>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
default:
throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8.");
}
}
void _qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& biases,
int bits,
int group_size,
bool transposed_w) {
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M);
for (int i = 0; i < batch_size; i++) {
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>() + i * M * N,
x.data<float>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<float>() + elem_to_loc(i * g_els, scales),
biases.data<float>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>() + i * M * N,
x.data<float16_t>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<float16_t>() + elem_to_loc(i * g_els, scales),
biases.data<float16_t>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>() + i * M * N,
x.data<bfloat16_t>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales),
biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
}
void _bs_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& biases,
const array& lhs_indices,
const array& rhs_indices,
int bits,
int group_size,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>() + i * M * N,
x.data<float>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>() + i * M * N,
x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>() + i * M * N,
x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
}
} // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto ensure_row_contiguous = [](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto ensure_row_contiguous_last_dims = [](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
}
template <typename T, typename U>
void quantize(
const array& w_,
array& out_,
array& scales_,
array& biases_,
int bits,
int group_size) {
const T* w = w_.data<T>();
auto out = out_.data<U>();
T* scales = scales_.data<T>();
T* biases = biases_.data<T>();
T n_bins = (1 << bits) - 1;
T eps = 1e-7;
bool power_of_2_bits = is_power_of_2(bits);
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
int bytes_per_pack = power_of_2_bits ? 1 : 3;
int int_per_group = group_size * bytes_per_pack / el_per_int;
size_t n_groups = w_.size() / group_size;
for (size_t i = 0; i < n_groups; ++i) {
size_t w_idx = i * group_size;
T w_min = std::numeric_limits<float>::infinity();
T w_max = -w_min;
for (int j = 0; j < group_size; ++j) {
w_max = std::max(w_max, w[w_idx + j]);
w_min = std::min(w_min, w[w_idx + j]);
}
bool mask = std::abs(w_min) > std::abs(w_max);
T scale = std::max(T((w_max - w_min) / n_bins), eps);
scale = mask ? scale : -scale;
auto edge = mask ? w_min : w_max;
auto q0 = std::rint(edge / scale);
if (q0 == 0) {
scales[i] = scale;
biases[i] = 0;
} else {
scales[i] = edge / q0;
biases[i] = edge;
}
size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
uint32_t out_el = 0;
for (int k = 0; k < el_per_int; ++k) {
T w_el = w[w_idx + j * el_per_int + k];
w_el = std::rint((w_el - biases[i]) / scales[i]);
w_el = std::min(std::max(w_el, T(0)), n_bins);
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
}
if (power_of_2_bits) {
out[out_idx + j] = out_el;
} else {
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
}
}
}
}
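// Summary of the scheme above (illustrative): each group of group_size
// weights is stored as w_q = rint((w - bias) / scale) clipped to
// [0, 2^bits - 1], and the matmul kernels earlier in this file dequantize on
// the fly as scale * w_q + bias.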
void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto ensure_row_contiguous = [](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto w = ensure_row_contiguous(inputs[0]);
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
}
} // namespace mlx::core

View File

@@ -1,377 +1,147 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <functional>
#include <limits>
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/common/simd/simd.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename U>
struct Limits {
static const U max;
static const U min;
};
#define instantiate_default_limit(type) \
template <> \
struct Limits<type> { \
static constexpr type max = std::numeric_limits<type>::max(); \
static constexpr type min = std::numeric_limits<type>::min(); \
};
instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);
#define instantiate_float_limit(type) \
template <> \
struct Limits<type> { \
static const type max; \
static const type min; \
};
instantiate_float_limit(float16_t);
instantiate_float_limit(bfloat16_t);
instantiate_float_limit(float);
instantiate_float_limit(complex64_t);
template <>
struct Limits<bool> {
static constexpr bool max = true;
static constexpr bool min = false;
};
const float Limits<float>::max = std::numeric_limits<float>::infinity();
const float Limits<float>::min = -std::numeric_limits<float>::infinity();
const bfloat16_t Limits<bfloat16_t>::max =
std::numeric_limits<float>::infinity();
const bfloat16_t Limits<bfloat16_t>::min =
-std::numeric_limits<float>::infinity();
const float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();
const float16_t Limits<float16_t>::min =
-std::numeric_limits<float>::infinity();
const complex64_t Limits<complex64_t>::max =
std::numeric_limits<float>::infinity();
const complex64_t Limits<complex64_t>::min =
-std::numeric_limits<float>::infinity();
struct AndReduce {
template <typename T>
bool operator()(bool x, T y) {
return x & (y != 0);
}
bool operator()(bool x, bool y) {
return x & y;
}
template <int N, typename T>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
return x & (y != 0);
};
template <int N>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
return x & y;
};
template <int N, typename T>
bool operator()(simd::Simd<T, N> x) {
return simd::all(x);
};
};
struct OrReduce {
template <typename T>
bool operator()(bool x, T y) {
return x | (y != 0);
}
bool operator()(bool x, bool y) {
return x | y;
}
template <int N, typename T>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
return x | (y != 0);
};
template <int N>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
return x | y;
};
template <int N, typename T>
bool operator()(simd::Simd<T, N> x) {
return simd::any(x);
};
};
struct MaxReduce {
template <typename T>
T operator()(T y, T x) {
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
};
template <int N, typename T>
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
return simd::maximum(x, y);
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::max(x);
};
};
struct MinReduce {
template <typename T>
T operator()(T y, T x) {
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
};
template <int N, typename T>
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
return simd::minimum(x, y);
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::min(x);
};
};
struct SumReduce {
template <typename T, typename U>
U operator()(U y, T x) {
return x + y;
};
template <int N, typename T, typename U>
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
return y + x;
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::sum(x);
};
};
struct ProdReduce {
template <typename T, typename U>
U operator()(U y, T x) {
return x * y;
};
template <int N, typename T, typename U>
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
return x * y;
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::prod(x);
};
};
template <typename InT>
void reduce_dispatch_and_or(
const array& in,
array& out,
Reduce::ReduceType rtype,
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
if (rtype == Reduce::And) {
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
} else {
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
auto shape = x.shape();
auto strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
shape.erase(shape.begin() + a);
strides.erase(strides.begin() + a);
}
return std::make_pair(shape, strides);
}
template <typename InT>
void reduce_dispatch_sum_prod(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Sum) {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
} else {
reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
x.flags().contiguous) {
return ContiguousAllReduce;
}
// Row contiguous input so the output is row contiguous
if (x.flags().row_contiguous) {
// Merge consecutive axes
Shape shape = {x.shape(axes[0])};
Strides strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
shape.back() *= x.shape(axes[i]);
strides.back() = x.strides()[axes[i]];
} else {
shape.push_back(x.shape(axes[i]));
strides.push_back(x.strides()[axes[i]]);
}
}
} else {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
} else {
reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
// Remove singleton axes from the plan
for (int i = shape.size() - 1; i >= 0; i--) {
if (shape[i] == 1) {
shape.erase(shape.begin() + i);
strides.erase(strides.begin() + i);
}
}
if (strides.back() == 1) {
return ReductionPlan(ContiguousReduce, shape, strides);
} else if (strides.back() > 1) {
return ReductionPlan(ContiguousStridedReduce, shape, strides);
}
}
}
template <typename InT>
void reduce_dispatch_min_max(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Max) {
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
} else {
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
}
}
// Let's check if we can optimize our access patterns
//
// 1. We have a reduction axis with stride 1. Simply call
// GeneralContiguousReduce and be done with it.
// 2. We have transpositions and we are not reducing over the axis with
// stride 1. However, we are reducing over an axis where everything is
// contiguous in memory to the right of that axis. We can call strided
// reduce and be done with it.
// 3. We have weird transpositions and expands. Copy the strides to the
// output, then call strided reduce.
} // namespace
void nd_loop(
std::function<void(int)> callback,
const Shape& shape,
const Strides& strides) {
std::function<void(int, int)> loop_inner;
loop_inner = [&](int dim, int offset) {
if (dim < shape.size() - 1) {
auto size = shape[dim];
auto stride = strides[dim];
for (int i = 0; i < size; i++) {
loop_inner(dim + 1, offset + i * stride);
}
} else {
auto size = shape[dim];
auto stride = strides[dim];
for (int i = 0; i < size; i++) {
callback(offset + i * stride);
}
}
};
loop_inner(0, 0);
}
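// Illustrative example: nd_loop(callback, {2, 3}, {3, 1}) invokes the
// callback with offsets 0, 1, 2, 3, 4, 5, i.e. a row-major walk over a 2x3
// array with strides {3, 1}.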
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (reduce_type_) {
case Reduce::And:
case Reduce::Or: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
case int64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
case Reduce::Sum:
case Reduce::Prod: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
case Reduce::Max:
case Reduce::Min: {
switch (in.dtype()) {
case bool_:
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
break;
case uint8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break;
// Sort reduction axes by stride in order to merge them and figure out if we
// have a contiguous reduction.
std::vector<std::pair<int, int64_t>> reductions;
for (auto a : axes) {
if (x.shape(a) > 1) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
}
}
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
bool a_is_zero = a.second == 0;
bool b_is_zero = b.second == 0;
return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
});
// Extract the two smallest and try to merge them in case the contiguous
// reduction can be bigger than just the last axis.
for (int i = reductions.size() - 1; i >= 1; i--) {
auto a = reductions[i];
auto b = reductions[i - 1];
// If b.stride == a.shape * a.stride then a and b are contiguous
if (b.second == a.first * a.second) {
reductions.erase(reductions.begin() + i);
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
}
}
Shape shape;
Strides strides;
for (auto r : reductions) {
shape.push_back(r.first);
strides.push_back(r.second);
}
// We can call the contiguous reduction op for every weird way the input is
// structured in the rest of the axes.
if (strides.back() == 1) {
return ReductionPlan(GeneralContiguousReduce, shape, strides);
}
// Delegate to the general strided reduction op if the axes after
// strides.back() are contiguous.
if (strides.back() > 1) {
int64_t size = 1;
bool have_expand = false;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
}
auto stride_i = x.strides()[i];
auto shape_i = x.shape(i);
if (stride_i == 0) {
if (shape_i == 1) {
continue;
}
have_expand = true;
break;
}
if (stride_i != size && shape_i != 1) {
break;
}
size *= shape_i;
}
// In the case of an expanded dimension we are being conservative and
// require the smallest reduction stride to be smaller than the maximum row
// contiguous size. The reason is that we can't easily know if the reduced
// axis is before or after an expanded dimension.
if (size > strides.back() || (size == strides.back() && !have_expand)) {
return ReductionPlan(GeneralStridedReduce, shape, strides);
}
}
return ReductionPlan(GeneralReduce, shape, strides);
}
} // namespace mlx::core

View File

@@ -2,7 +2,6 @@
#pragma once
#include "mlx/backend/common/simd/simd.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
@@ -49,193 +48,8 @@ struct ReductionPlan {
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
// Helper for the n-dimensional strided loop
// Should this be in utils?
void nd_loop(
std::function<void(int)> callback,
const Shape& shape,
const Strides& strides);
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes);
template <typename T, typename U, typename Op>
void strided_reduce(
const T* x,
U* accumulator,
int size,
size_t stride,
Op op) {
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
for (int i = 0; i < size; i++) {
U* moving_accumulator = accumulator;
auto s = stride;
while (s >= N) {
auto acc = simd::load<U, N>(moving_accumulator);
auto v = simd::Simd<U, N>(simd::load<T, N>(x));
simd::store<U, N>(moving_accumulator, op(acc, v));
moving_accumulator += N;
x += N;
s -= N;
}
while (s-- > 0) {
*moving_accumulator = op(*moving_accumulator, *x);
moving_accumulator++;
x++;
}
}
};
template <typename T, typename U, typename Op>
void contiguous_reduce(const T* x, U* accumulator, int size, Op op, U init) {
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
simd::Simd<U, N> accumulator_v(init);
while (size >= N) {
accumulator_v = op(accumulator_v, simd::Simd<U, N>(simd::load<T, N>(x)));
x += N;
size -= N;
}
*accumulator = op(*accumulator, op(accumulator_v));
while (size-- > 0) {
*accumulator = op(*accumulator, *x);
x++;
}
}
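// Sketch of the flow above (illustrative): summing 10 floats with N == 8
// consumes one full SIMD chunk, folds it into the scalar accumulator via the
// horizontal op(accumulator_v), and handles the remaining 2 elements in the
// scalar tail loop.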
template <typename T, typename U, typename Op>
void reduction_op(
const array& x,
array& out,
const std::vector<int>& axes,
U init,
Op op) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
ReductionPlan plan = get_reduction_plan(x, axes);
if (plan.type == ContiguousAllReduce) {
U* out_ptr = out.data<U>();
*out_ptr = init;
contiguous_reduce(x.data<T>(), out_ptr, x.size(), op, init);
return;
}
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0];
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
*out_ptr = init;
contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init);
}
return;
}
if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) {
int reduction_size = plan.shape.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
// Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should yield an extra performance boost.
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init);
}
} else {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
nd_loop(
[&](int extra_offset) {
contiguous_reduce(
x_ptr + offset + extra_offset,
out_ptr,
reduction_size,
op,
init);
},
plan.shape,
plan.strides);
}
}
return;
}
if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape.back();
size_t reduction_stride = plan.strides.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i += reduction_stride) {
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op);
x_ptr += reduction_stride * reduction_size;
out_ptr += reduction_stride;
}
return;
}
if (plan.type == GeneralStridedReduce ||
plan.type == ContiguousStridedReduce) {
int reduction_size = plan.shape.back();
size_t reduction_stride = plan.strides.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(
x_ptr + offset, out_ptr, reduction_size, reduction_stride, op);
out_ptr += reduction_stride;
}
} else {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
nd_loop(
[&](int extra_offset) {
strided_reduce(
x_ptr + offset + extra_offset,
out_ptr,
reduction_size,
reduction_stride,
op);
},
plan.shape,
plan.strides);
out_ptr += reduction_stride;
}
}
return;
}
if (plan.type == GeneralReduce) {
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;
nd_loop(
[&](int extra_offset) {
val = op(val, *(x_ptr + offset + extra_offset));
},
plan.shape,
plan.strides);
*out_ptr = val;
}
}
}
} // namespace mlx::core

View File

@@ -1,147 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/reduce.h"
namespace mlx::core {
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
shape.erase(shape.begin() + a);
strides.erase(strides.begin() + a);
}
return std::make_pair(shape, strides);
}
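// Example (illustrative): for an array of shape {2, 3, 4} with strides
// {12, 4, 1} and axes = {1}, this returns shape {2, 4} and strides {12, 1},
// i.e. the iteration space of the output with the reduced axis removed.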
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
x.flags().contiguous) {
return ContiguousAllReduce;
}
// Row contiguous input so the output is row contiguous
if (x.flags().row_contiguous) {
// Merge consecutive axes
Shape shape = {x.shape(axes[0])};
Strides strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
shape.back() *= x.shape(axes[i]);
strides.back() = x.strides()[axes[i]];
} else {
shape.push_back(x.shape(axes[i]));
strides.push_back(x.strides()[axes[i]]);
}
}
// Remove singleton axes from the plan
for (int i = shape.size() - 1; i >= 0; i--) {
if (shape[i] == 1) {
shape.erase(shape.begin() + i);
strides.erase(strides.begin() + i);
}
}
if (strides.back() == 1) {
return ReductionPlan(ContiguousReduce, shape, strides);
} else if (strides.back() > 1) {
return ReductionPlan(ContiguousStridedReduce, shape, strides);
}
}
// Let's check if we can optimize our access patterns
//
// 1. We have a reduction axis with stride 1. Simply call
// GeneralContiguousReduce and be done with it.
// 2. We have transpositions and we are not reducing over the axis with
// stride 1. However, we are reducing over an axis where everything is
// contiguous in memory to the right of that axis. We can call strided
// reduce and be done with it.
// 3. We have weird transpositions and expands. Copy the strides to the
// output, then call strided reduce.
// Sort reduction axes by stride in order to merge them and figure out if we
// have a contiguous reduction.
std::vector<std::pair<int, int64_t>> reductions;
for (auto a : axes) {
if (x.shape(a) > 1) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
}
}
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
bool a_is_zero = a.second == 0;
bool b_is_zero = b.second == 0;
return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
});
// Extract the two smallest and try to merge them in case the contiguous
// reduction can be bigger than just the last axis.
for (int i = reductions.size() - 1; i >= 1; i--) {
auto a = reductions[i];
auto b = reductions[i - 1];
// If b.stride == a.shape * a.stride then a and b are contiguous
if (b.second == a.first * a.second) {
reductions.erase(reductions.begin() + i);
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
}
}
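// Merging example (illustrative): reduction axes with (shape, stride) pairs
// {(4, 8), (8, 1)} satisfy b.stride == a.shape * a.stride (8 == 8 * 1), so
// the two axes collapse into a single contiguous reduction of (32, 1).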
Shape shape;
Strides strides;
for (auto r : reductions) {
shape.push_back(r.first);
strides.push_back(r.second);
}
// We can call the contiguous reduction op for every weird way the input is
// structured in the rest of the axes.
if (strides.back() == 1) {
return ReductionPlan(GeneralContiguousReduce, shape, strides);
}
// Delegate to the general strided reduction op if the axes after
// strides.back() are contiguous.
if (strides.back() > 1) {
int64_t size = 1;
bool have_expand = false;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
}
auto stride_i = x.strides()[i];
auto shape_i = x.shape(i);
if (stride_i == 0) {
if (shape_i == 1) {
continue;
}
have_expand = true;
break;
}
if (stride_i != size && shape_i != 1) {
break;
}
size *= shape_i;
}
// In the case of an expanded dimension we are being conservative and
// require the smallest reduction stride to be smaller than the maximum row
// contiguous size. The reason is that we can't easily know if the reduced
// axis is before or after an expanded dimension.
if (size > strides.back() || (size == strides.back() && !have_expand)) {
return ReductionPlan(GeneralStridedReduce, shape, strides);
}
}
return ReductionPlan(GeneralReduce, shape, strides);
}
} // namespace mlx::core

View File

@@ -1,312 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/simd/simd.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T, typename U, typename Op>
void contiguous_scan(
const T* input,
U* output,
int count,
int stride,
bool reverse,
bool inclusive,
const Op& op,
U init) {
if (!reverse) {
if (inclusive) {
for (int i = 0; i < count; i++) {
*output = *input;
for (int j = 1; j < stride; j++) {
input++;
output++;
*output = op(*(output - 1), *input);
}
output++;
input++;
}
} else {
for (int i = 0; i < count; i++) {
*output = init;
for (int j = 1; j < stride; j++) {
*(output + 1) = op(*output, *input);
input++;
output++;
}
output++;
input++;
}
}
} else {
if (inclusive) {
for (int i = 0; i < count; i++) {
output += stride - 1;
input += stride - 1;
*output = *input;
for (int j = 1; j < stride; j++) {
input--;
output--;
*output = op(*(output + 1), *input);
}
output += stride;
input += stride;
}
} else {
for (int i = 0; i < count; i++) {
output += stride - 1;
input += stride - 1;
*output = init;
for (int j = 1; j < stride; j++) {
*(output - 1) = op(*output, *input);
input--;
output--;
}
output += stride;
input += stride;
}
}
}
};
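// Example (illustrative): a forward inclusive Sum scan over the row
// {1, 2, 3, 4} produces {1, 3, 6, 10}; the exclusive variant seeds the first
// output with init (0 for Sum) and produces {0, 1, 3, 6}.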
template <typename T, typename U, typename Op>
void strided_scan(
const T* input,
U* output,
int count,
int size,
int stride,
bool reverse,
bool inclusive,
const Op& op,
U init) {
// TODO: Vectorize the following naive implementation
if (!reverse) {
if (inclusive) {
for (int i = 0; i < count; i++) {
std::copy(input, input + stride, output);
output += stride;
input += stride;
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
*output = op(*(output - stride), *input);
output++;
input++;
}
}
}
} else {
for (int i = 0; i < count; i++) {
std::fill(output, output + stride, init);
output += stride;
input += stride;
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
*output = op(*(output - stride), *(input - stride));
output++;
input++;
}
}
}
}
} else {
if (inclusive) {
for (int i = 0; i < count; i++) {
output += (size - 1) * stride;
input += (size - 1) * stride;
std::copy(input, input + stride, output);
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
output--;
input--;
*output = op(*(output + stride), *input);
}
}
output += size * stride;
input += size * stride;
}
} else {
for (int i = 0; i < count; i++) {
output += (size - 1) * stride;
input += (size - 1) * stride;
std::fill(output, output + stride, init);
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
output--;
input--;
*output = op(*(output + stride), *(input + stride));
}
}
output += size * stride;
input += size * stride;
}
}
}
};
template <typename T, typename U, typename Op>
void scan_op(
const array& input,
array& output,
int axis,
bool reverse,
bool inclusive,
const Op& op,
U init) {
output.set_data(allocator::malloc_or_wait(output.nbytes()));
if (input.flags().row_contiguous) {
if (input.strides()[axis] == 1) {
contiguous_scan(
input.data<T>(),
output.data<U>(),
input.size() / input.shape(axis),
input.shape(axis),
reverse,
inclusive,
op,
init);
} else {
strided_scan(
input.data<T>(),
output.data<U>(),
input.size() / input.shape(axis) / input.strides()[axis],
input.shape(axis),
input.strides()[axis],
reverse,
inclusive,
op,
init);
}
} else {
throw std::runtime_error("Scan op supports only contiguous inputs");
}
}
template <typename T, typename U>
void scan_dispatch(
Scan::ReduceType rtype,
const array& input,
array& output,
int axis,
bool reverse,
bool inclusive) {
switch (rtype) {
case Scan::Sum: {
auto op = [](U y, T x) { return y + x; };
auto init = static_cast<U>(0);
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
break;
}
case Scan::Prod: {
auto op = [](U y, T x) { return y * x; };
auto init = static_cast<U>(1);
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
break;
}
case Scan::Min: {
auto op = [](U y, T x) { return x < y ? x : y; };
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
break;
}
case Scan::Max: {
auto op = [](U y, T x) { return x < y ? y : x; };
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
break;
}
}
}
} // namespace
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Ensure contiguity
auto in = inputs[0];
if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General);
in = arr_copy;
}
switch (in.dtype()) {
case bool_: {
// We could do a full dtype x dtype switch but this is the only case
// where we accumulate in a different type, for now.
//
// TODO: If we add the option to accumulate floats in higher precision
// floats perhaps we should add the full all-to-all dispatch.
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
scan_dispatch<bool, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
} else {
scan_dispatch<bool, bool>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
}
break;
}
case uint8:
scan_dispatch<uint8_t, uint8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint16:
scan_dispatch<uint16_t, uint16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint32:
scan_dispatch<uint32_t, uint32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint64:
scan_dispatch<uint64_t, uint64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int8:
scan_dispatch<int8_t, int8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int16:
scan_dispatch<int16_t, int16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int32:
scan_dispatch<int32_t, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int64:
scan_dispatch<int64_t, int64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float16:
scan_dispatch<float16_t, float16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float32:
scan_dispatch<float, float>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case bfloat16:
scan_dispatch<bfloat16_t, bfloat16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case complex64:
throw std::runtime_error("Scan ops do not support complex types yet");
break;
}
}
} // namespace mlx::core

View File

@@ -1,73 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/backend/common/binary_ops.h"
#include "mlx/backend/common/ternary.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename Op>
void select_op(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
switch (out.dtype()) {
case bool_:
ternary_op<bool, bool, bool, bool>(a, b, c, out, op);
break;
case uint8:
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op);
break;
case uint16:
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op);
break;
case uint32:
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op);
break;
case uint64:
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op);
break;
case int8:
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op);
break;
case int16:
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op);
break;
case int32:
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op);
break;
case int64:
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op);
break;
case float16:
ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op);
break;
case float32:
ternary_op<bool, float, float, float>(a, b, c, out, op);
break;
case bfloat16:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
break;
case complex64:
ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op);
break;
}
}
} // namespace
void Select::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
const auto& condition = inputs[0];
const auto& a = inputs[1];
const auto& b = inputs[2];
select_op(condition, a, b, out, detail::Select());
}
} // namespace mlx::core

View File

@@ -1,56 +0,0 @@
#pragma once
#include "mlx/backend/common/simd/base_simd.h"
#if MLX_SIMD_LIBRARY_VERSION < 6
#include "mlx/backend/common/simd/neon_fp16_simd.h"
#endif
namespace mlx::core::simd {
#if MLX_SIMD_LIBRARY_VERSION >= 6
constexpr int N = 8;
template <int N>
struct ScalarT<float16_t, N> {
using v = _Float16;
};
#endif
template <>
static constexpr int max_size<float16_t> = N;
#define SIMD_FP16_DEFAULT_UNARY(op) \
template <> \
inline Simd<float16_t, N> op(Simd<float16_t, N> v) { \
Simd<float, N> in = v; \
return op(in); \
}
SIMD_FP16_DEFAULT_UNARY(acos)
SIMD_FP16_DEFAULT_UNARY(acosh)
SIMD_FP16_DEFAULT_UNARY(asin)
SIMD_FP16_DEFAULT_UNARY(asinh)
SIMD_FP16_DEFAULT_UNARY(atan)
SIMD_FP16_DEFAULT_UNARY(atanh)
SIMD_FP16_DEFAULT_UNARY(cosh)
SIMD_FP16_DEFAULT_UNARY(expm1)
SIMD_FP16_DEFAULT_UNARY(log)
SIMD_FP16_DEFAULT_UNARY(log2)
SIMD_FP16_DEFAULT_UNARY(log10)
SIMD_FP16_DEFAULT_UNARY(log1p)
SIMD_FP16_DEFAULT_UNARY(sinh)
SIMD_FP16_DEFAULT_UNARY(tan)
SIMD_FP16_DEFAULT_UNARY(tanh)
#define SIMD_FP16_DEFAULT_BINARY(op) \
template <> \
inline Simd<float16_t, N> op(Simd<float16_t, N> x, Simd<float16_t, N> y) { \
Simd<float, N> a = x; \
Simd<float, N> b = y; \
return op(a, b); \
}
SIMD_FP16_DEFAULT_BINARY(atan2)
SIMD_FP16_DEFAULT_BINARY(remainder)
SIMD_FP16_DEFAULT_BINARY(pow)
} // namespace mlx::core::simd

View File

@@ -1,303 +0,0 @@
#pragma once
#include <simd/math.h>
#include <simd/vector.h>
#include <stdint.h>
#include <cmath>
#include <complex>
#include "mlx/backend/common/simd/base_simd.h"
// There seems to be a bug in simd/base.h:
// __XROS_2_0 is not defined, so the expression evaluates
// to true instead of false, setting the SIMD library version
// higher than it should be even on macOS < 15
#if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 || \
__IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \
__WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
__TV_OS_VERSION_MIN_REQUIRED >= 180000
#define MLX_SIMD_LIBRARY_VERSION 6
#else
#define MLX_SIMD_LIBRARY_VERSION 5
#endif
namespace mlx::core::simd {
// Apple simd namespace
namespace asd = ::simd;
// This indirection is needed to remap certain types to ones that accelerate
// SIMD can handle
template <typename T, int N>
struct ScalarT {
using v = T;
};
template <int N>
struct ScalarT<bool, N> {
using v = char;
};
template <int N>
struct ScalarT<int8_t, N> {
using v = char;
};
template <int N>
struct ScalarT<uint64_t, N> {
using v = unsigned long;
};
template <int N>
struct ScalarT<int64_t, N> {
using v = long;
};
template <typename T, int N>
struct Simd {
static constexpr int size = N;
using scalar_t = typename ScalarT<T, N>::v;
Simd<T, N>() {}
template <typename U>
Simd<T, N>(Simd<U, N> other) : value(asd::convert<scalar_t>(other.value)) {}
template <typename U>
Simd<T, N>(U v) : value(v){};
Simd<T, N>(Simd<T, N / 2> x, Simd<T, N / 2> y) {
value = asd::make<typename asd::Vector<scalar_t, N>::packed_t>(
x.value, y.value);
};
T operator[](int idx) const {
return reinterpret_cast<const T*>(&value)[idx];
}
T& operator[](int idx) {
return reinterpret_cast<T*>(&value)[idx];
}
typename asd::Vector<scalar_t, N>::packed_t value;
};
// Values chosen based on benchmarks on M3 Max
// TODO: consider choosing these more optimally
template <>
static constexpr int max_size<int8_t> = 16;
template <>
static constexpr int max_size<int16_t> = 16;
template <>
static constexpr int max_size<int> = 8;
template <>
static constexpr int max_size<int64_t> = 4;
template <>
static constexpr int max_size<uint8_t> = 16;
template <>
static constexpr int max_size<uint16_t> = 16;
template <>
static constexpr int max_size<uint32_t> = 8;
template <>
static constexpr int max_size<uint64_t> = 4;
template <>
static constexpr int max_size<float> = 8;
template <>
static constexpr int max_size<double> = 4;
#define SIMD_DEFAULT_UNARY(name, op) \
template <typename T, int N> \
Simd<T, N> name(Simd<T, N> v) { \
return op(v.value); \
}
SIMD_DEFAULT_UNARY(abs, asd::abs)
SIMD_DEFAULT_UNARY(floor, asd::floor)
SIMD_DEFAULT_UNARY(acos, asd::acos)
SIMD_DEFAULT_UNARY(acosh, asd::acosh)
SIMD_DEFAULT_UNARY(asin, asd::asin)
SIMD_DEFAULT_UNARY(asinh, asd::asinh)
SIMD_DEFAULT_UNARY(atan, asd::atan)
SIMD_DEFAULT_UNARY(atanh, asd::atanh)
SIMD_DEFAULT_UNARY(ceil, asd::ceil)
SIMD_DEFAULT_UNARY(cosh, asd::cosh)
SIMD_DEFAULT_UNARY(expm1, asd::expm1)
SIMD_DEFAULT_UNARY(log, asd::log)
SIMD_DEFAULT_UNARY(log2, asd::log2)
SIMD_DEFAULT_UNARY(log10, asd::log10)
SIMD_DEFAULT_UNARY(log1p, asd::log1p)
SIMD_DEFAULT_UNARY(rint, asd::rint)
SIMD_DEFAULT_UNARY(sinh, asd::sinh)
SIMD_DEFAULT_UNARY(sqrt, asd::sqrt)
SIMD_DEFAULT_UNARY(rsqrt, asd::rsqrt)
SIMD_DEFAULT_UNARY(recip, asd::recip)
SIMD_DEFAULT_UNARY(tan, asd::tan)
SIMD_DEFAULT_UNARY(tanh, asd::tanh)
template <typename T, int N>
Simd<T, N> operator-(Simd<T, N> v) {
return -v.value;
}
template <typename T, int N>
Simd<bool, N> isnan(Simd<T, N> v) {
return asd::convert<char>(v.value != v.value);
}
// No simd_boolN in accelerate, use int8_t instead
template <typename T, int N>
Simd<bool, N> operator!(Simd<T, N> v) {
return asd::convert<char>(!v.value);
}
#define SIMD_DEFAULT_BINARY(OP) \
template <typename T, typename U, int N> \
Simd<T, N> operator OP(Simd<T, N> x, U y) { \
return asd::convert<typename Simd<T, N>::scalar_t>(x.value OP y); \
} \
template <typename T1, typename T2, int N> \
Simd<T2, N> operator OP(T1 x, Simd<T2, N> y) { \
return asd::convert<typename Simd<T2, N>::scalar_t>(x OP y.value); \
} \
template <typename T1, typename T2, int N> \
Simd<T1, N> operator OP(Simd<T1, N> x, Simd<T2, N> y) { \
return asd::convert<typename Simd<T1, N>::scalar_t>(x.value OP y.value); \
}
SIMD_DEFAULT_BINARY(+)
SIMD_DEFAULT_BINARY(-)
SIMD_DEFAULT_BINARY(/)
SIMD_DEFAULT_BINARY(*)
SIMD_DEFAULT_BINARY(<<)
SIMD_DEFAULT_BINARY(>>)
SIMD_DEFAULT_BINARY(|)
SIMD_DEFAULT_BINARY(^)
SIMD_DEFAULT_BINARY(&)
SIMD_DEFAULT_BINARY(&&)
SIMD_DEFAULT_BINARY(||)
#define SIMD_DEFAULT_COMPARISONS(OP) \
template <int N, typename T, typename U> \
Simd<bool, N> operator OP(Simd<T, N> a, U b) { \
return asd::convert<char>(a.value OP b); \
} \
template <int N, typename T, typename U> \
Simd<bool, N> operator OP(T a, Simd<U, N> b) { \
return asd::convert<char>(a OP b.value); \
} \
template <int N, typename T1, typename T2> \
Simd<bool, N> operator OP(Simd<T1, N> a, Simd<T2, N> b) { \
return asd::convert<char>(a.value OP b.value); \
}
SIMD_DEFAULT_COMPARISONS(>)
SIMD_DEFAULT_COMPARISONS(<)
SIMD_DEFAULT_COMPARISONS(>=)
SIMD_DEFAULT_COMPARISONS(<=)
SIMD_DEFAULT_COMPARISONS(==)
SIMD_DEFAULT_COMPARISONS(!=)
template <typename T, int N>
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
return asd::atan2(a.value, b.value);
}
template <typename T, int N>
Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
// TODO add isnan
return asd::max(a.value, b.value);
}
template <typename T, int N>
Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {
// TODO add isnan
return asd::min(a.value, b.value);
}
template <typename T, int N>
Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
Simd<T, N> r;
if constexpr (!std::is_integral_v<T>) {
r = asd::remainder(a.value, b.value);
} else {
r = a - b * (a / b);
}
if constexpr (std::is_signed_v<T>) {
auto mask = r != 0 && (r < 0 != b < 0);
r = select(mask, r + b, r);
}
return r;
}
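// Illustrative note: for signed integers the correction above turns the
// truncating C++ remainder into a floored one, e.g. remainder(-7, 3) first
// computes r = -7 - 3 * (-7 / 3) = -1 and then, since r and b differ in
// sign, returns r + b = 2.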
template <typename MaskT, typename T1, typename T2, int N>
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
if constexpr (sizeof(T1) == 1) {
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
} else if constexpr (sizeof(T1) == 2) {
return asd::bitselect(y.value, x.value, asd::convert<short>(mask.value));
} else if constexpr (sizeof(T1) == 4) {
return asd::bitselect(y.value, x.value, asd::convert<int>(mask.value));
} else {
return asd::bitselect(y.value, x.value, asd::convert<long>(mask.value));
}
}
template <typename T, int N>
Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
if constexpr (!std::is_integral_v<T>) {
return asd::pow(base.value, exp.value);
} else {
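// Integer powers use a lane-wise square-and-multiply: each round multiplies
// res by base in the lanes whose current exponent bit is set, then squares
// base. E.g. base = 3, exp = 5 (0b101): res picks up 3 and then 81, i.e. 243.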
Simd<T, N> res = 1;
while (any(exp)) {
res = select(exp & 1, res * base, res);
base = select(exp, base * base, base);
exp = exp >> 1;
}
return res;
}
}
template <typename T, int N>
Simd<T, N> clamp(Simd<T, N> v, Simd<T, N> min, Simd<T, N> max) {
return asd::clamp(v.value, min.value, max.value);
}
template <typename T, typename U, int N>
Simd<T, N> fma(Simd<T, N> x, Simd<T, N> y, U z) {
return asd::muladd(x.value, y.value, Simd<T, N>(z).value);
}
// Reductions
template <typename T, int N>
bool all(Simd<T, N> x) {
return asd::all(x.value);
}
template <typename T, int N>
bool any(Simd<T, N> x) {
return asd::any(x.value);
}
template <typename T, int N>
T sum(Simd<T, N> x) {
return asd::reduce_add(x.value);
}
template <typename T, int N>
T max(Simd<T, N> x) {
return asd::reduce_max(x.value);
}
template <typename T, int N>
T min(Simd<T, N> x) {
return asd::reduce_min(x.value);
}
template <typename T, int N>
T prod(Simd<T, N> x) {
auto ptr = (T*)&x;
auto lhs = load<T, N / 2>(ptr);
auto rhs = load<T, N / 2>(ptr + N / 2);
return prod(lhs * rhs);
}
} // namespace mlx::core::simd
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "mlx/backend/common/simd/accelerate_fp16_simd.h"
#endif

View File

@@ -1,253 +0,0 @@
#pragma once
#include <stdint.h>
#include <algorithm>
#include <cmath>
#include <complex>
namespace mlx::core::simd {
template <typename T, int N>
struct Simd;
template <typename T>
static constexpr int max_size = 1;
template <typename T>
struct Simd<T, 1> {
static constexpr int size = 1;
T value;
Simd() {}
template <typename U>
Simd(Simd<U, 1> v) : value(v.value) {}
template <typename U>
Simd(U v) : value(v) {}
};
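// This size-1 specialization is the portable scalar fallback: kernels are
// written against Simd<T, N> and transparently pick up wider vectors when a
// platform header overrides max_size<T> and provides Simd<T, N> for N > 1.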
template <typename T, int N>
Simd<T, N> load(const T* x) {
return *(Simd<T, N>*)x;
}
template <typename T, int N>
void store(T* dst, Simd<T, N> x) {
// Maintain invariant that bool is either 0 or 1 as
// simd comparison ops set all bits in the result to 1
if constexpr (std::is_same_v<T, bool> && N > 1) {
x = x & 1;
}
*(Simd<T, N>*)dst = x;
}
template <typename, typename = void>
constexpr bool is_complex = false;
template <typename T>
constexpr bool is_complex<T, std::void_t<decltype(std::declval<T>().real())>> =
true;
template <typename T>
Simd<T, 1> rint(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
return Simd<T, 1>{
T{std::rint(in.value.real()), std::rint(in.value.imag())}};
} else {
return Simd<T, 1>{std::rint(in.value)};
}
}
template <typename T>
Simd<T, 1> rsqrt(Simd<T, 1> in) {
return T(1.0) / sqrt(in);
}
template <typename T>
Simd<T, 1> recip(Simd<T, 1> in) {
return T(1.0) / in;
}
#define DEFAULT_UNARY(name, op) \
template <typename T> \
Simd<T, 1> name(Simd<T, 1> in) { \
return op(in.value); \
}
DEFAULT_UNARY(operator-, std::negate{})
DEFAULT_UNARY(operator!, std::logical_not{})
DEFAULT_UNARY(abs, std::abs)
DEFAULT_UNARY(acos, std::acos)
DEFAULT_UNARY(acosh, std::acosh)
DEFAULT_UNARY(asin, std::asin)
DEFAULT_UNARY(asinh, std::asinh)
DEFAULT_UNARY(atan, std::atan)
DEFAULT_UNARY(atanh, std::atanh)
DEFAULT_UNARY(ceil, std::ceil)
DEFAULT_UNARY(conj, std::conj)
DEFAULT_UNARY(cosh, std::cosh)
DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log2, std::log2)
DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh)
DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
auto real(Simd<T, 1> in) -> Simd<decltype(std::real(in.value)), 1> {
return std::real(in.value);
}
template <typename T>
auto imag(Simd<T, 1> in) -> Simd<decltype(std::imag(in.value)), 1> {
return std::imag(in.value);
}
template <typename T>
Simd<bool, 1> isnan(Simd<T, 1> in) {
return std::isnan(in.value);
}
#define DEFAULT_BINARY(OP) \
template <typename T1, typename T2> \
auto operator OP(Simd<T1, 1> a, Simd<T2, 1> b) \
->Simd<decltype(a.value OP b.value), 1> { \
return a.value OP b.value; \
} \
template <typename T1, typename T2> \
auto operator OP(T1 a, Simd<T2, 1> b)->Simd<decltype(a OP b.value), 1> { \
return a OP b.value; \
} \
template <typename T1, typename T2> \
auto operator OP(Simd<T1, 1> a, T2 b)->Simd<decltype(a.value OP b), 1> { \
return a.value OP b; \
}
DEFAULT_BINARY(+)
DEFAULT_BINARY(-)
DEFAULT_BINARY(*)
DEFAULT_BINARY(/)
DEFAULT_BINARY(<<)
DEFAULT_BINARY(>>)
DEFAULT_BINARY(|)
DEFAULT_BINARY(^)
DEFAULT_BINARY(&)
DEFAULT_BINARY(&&)
DEFAULT_BINARY(||)
template <typename T>
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;
T b = b_.value;
T r;
if constexpr (std::is_integral_v<T>) {
r = a % b;
} else {
r = std::remainder(a, b);
}
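// Adjust the sign so a nonzero result follows the divisor, e.g. for integers
// remainder(-7, 3): a % b gives -1, which the fix-up below turns into 2.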
if constexpr (std::is_signed_v<T>) {
if (r != 0 && (r < 0 != b < 0)) {
r += b;
}
}
return r;
}
template <typename T>
Simd<T, 1> maximum(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;
T b = b_.value;
if constexpr (!std::is_integral_v<T>) {
if (std::isnan(a)) {
return a;
}
}
return (a > b) ? a : b;
}
template <typename T>
Simd<T, 1> minimum(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;
T b = b_.value;
if constexpr (!std::is_integral_v<T>) {
if (std::isnan(a)) {
return a;
}
}
return (a < b) ? a : b;
}
template <typename T>
Simd<T, 1> pow(Simd<T, 1> a, Simd<T, 1> b) {
T base = a.value;
T exp = b.value;
if constexpr (!std::is_integral_v<T>) {
return std::pow(base, exp);
} else {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
}
template <typename T>
Simd<T, 1> atan2(Simd<T, 1> a, Simd<T, 1> b) {
return std::atan2(a.value, b.value);
}
#define DEFAULT_COMPARISONS(OP) \
template <typename T1, typename T2> \
Simd<bool, 1> operator OP(Simd<T1, 1> a, Simd<T2, 1> b) { \
return a.value OP b.value; \
} \
template <typename T1, typename T2> \
Simd<bool, 1> operator OP(T1 a, Simd<T2, 1> b) { \
return a OP b.value; \
} \
template <typename T1, typename T2> \
Simd<bool, 1> operator OP(Simd<T1, 1> a, T2 b) { \
return a.value OP b; \
}
DEFAULT_COMPARISONS(>)
DEFAULT_COMPARISONS(<)
DEFAULT_COMPARISONS(>=)
DEFAULT_COMPARISONS(<=)
DEFAULT_COMPARISONS(==)
DEFAULT_COMPARISONS(!=)
template <typename MaskT, typename T>
Simd<T, 1> select(Simd<MaskT, 1> mask, Simd<T, 1> x, Simd<T, 1> y) {
return mask.value ? x.value : y.value;
}
template <typename T>
Simd<T, 1> clamp(Simd<T, 1> v, Simd<T, 1> min, Simd<T, 1> max) {
return std::clamp(v.value, min.value, max.value);
}
template <typename T, typename U>
Simd<T, 1> fma(Simd<T, 1> x, Simd<T, 1> y, U z) {
return std::fma(x.value, y.value, Simd<T, 1>(z).value);
}
// Reductions
#define DEFAULT_REDUCTION(name, type) \
template <typename T> \
type name(Simd<T, 1> x) { \
return x.value; \
}
DEFAULT_REDUCTION(max, T)
DEFAULT_REDUCTION(min, T)
DEFAULT_REDUCTION(sum, T)
DEFAULT_REDUCTION(prod, T)
DEFAULT_REDUCTION(any, bool)
DEFAULT_REDUCTION(all, bool)
} // namespace mlx::core::simd

View File

@@ -1,193 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/common/simd/type.h"
namespace mlx::core::simd {
constexpr float inf = std::numeric_limits<float>::infinity();
/**
* Compute exp(x) in an optimizer friendly way as follows:
*
* First change the problem to computing 2**y where y = x / ln(2).
*
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
* `ipart` and y2 is fractional part. For the integer part we perform bit
* shifting and for the fractional part we use a polynomial approximation.
*
* The algorithm and constants of the polynomial taken from
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
* from Cephes math library.
*
* Note: The implementation below is a general fast exp. There could be faster
* implementations for numbers strictly < 0.
*/
template <typename T, int N>
Simd<T, N> exp(Simd<T, N> in) {
if constexpr (is_complex<T>) {
return Simd<T, 1>{std::exp(in.value)};
} else {
Simd<float, N> x_init = in;
auto x = x_init * 1.442695f; // multiply with log_2(e)
Simd<float, N> ipart, fpart;
ipart = floor(x + 0.5);
fpart = x - ipart;
x = 1.535336188319500e-4f;
x = fma(x, fpart, 1.339887440266574e-3f);
x = fma(x, fpart, 9.618437357674640e-3f);
x = fma(x, fpart, 5.550332471162809e-2f);
x = fma(x, fpart, 2.402264791363012e-1f);
x = fma(x, fpart, 6.931472028550421e-1f);
x = fma(x, fpart, 1.000000000000000f);
// generate 2**ipart in the floating point representation using integer
// bitshifting
Simd<int, N> epart = (Simd<int, N>(ipart) + 127) << 23;
// Deal with NaN and Inf
auto result = select(isnan(x_init), x_init, (*(Simd<float, N>*)&epart) * x);
result = select(x_init > 88.0f, Simd<float, N>(inf), result);
result = select(x_init < -88.0f, Simd<float, N>(0), result);
return Simd<T, N>(result);
}
}
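// Minimal scalar sketch of the same decomposition (illustrative only, not part
// of the original header; the helper name is hypothetical): split x / ln(2)
// into integer and fractional parts, approximate 2**fpart with the Cephes
// polynomial above, and build 2**ipart from the exponent bits. The NaN/Inf
// clamping done above is omitted here.
inline float fast_exp_scalar_sketch(float v) {
float y = v * 1.442695f; // v / ln(2)
float ipart = std::floor(y + 0.5f);
float fpart = y - ipart;
float p = 1.535336188319500e-4f; // 2**fpart, Horner evaluation
p = p * fpart + 1.339887440266574e-3f;
p = p * fpart + 9.618437357674640e-3f;
p = p * fpart + 5.550332471162809e-2f;
p = p * fpart + 2.402264791363012e-1f;
p = p * fpart + 6.931472028550421e-1f;
p = p * fpart + 1.0f;
int epart = (static_cast<int>(ipart) + 127) << 23; // bits of 2**ipart
return *(float*)&epart * p; // same bit-cast idiom as the vector code above
}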
/* Implementation from:
* https://github.com/JishinMaster/simd_utils/blob/3c1433a86fb38edcc9b02039f3c9a65b16640976/neon_mathfun.h#L357
* which originally came from the Cephes math library.
*/
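// Range reduction: |x| is scaled by 4/Pi and mapped into octants of width
// Pi/4; the octant index (emm2) selects between the sine and cosine
// polynomials and fixes the sign of the result.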
template <bool Sine, typename T, int N>
Simd<T, N> sincos(Simd<T, N> in) {
auto sign_mask_sin = in < 0;
in = abs(in);
Simd<float, N> x = in;
// scale by 4/Pi
auto y = x * 1.27323954473516f;
// store the integer part of y in emm2
Simd<uint32_t, N> emm2 = y;
// j=(j+1) & (~1) (see the cephes sources)
emm2 = emm2 + 1;
emm2 = emm2 & ~1;
y = emm2;
// Get the polynomial selection mask. There is one polynomial for 0 <= x <= Pi/4
// and another one for Pi/4 < x <= Pi/2. Both branches will be computed.
auto poly_mask = (emm2 & 2) != 0;
// The magic pass: "Extended precision modular arithmetic"
// x = ((x - y * DP1) - y * DP2) - y * DP3
x = fma(y, Simd<float, N>(-0.78515625f), x);
x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);
x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);
auto sign_mask_cos = ((emm2 - 2) & 4) != 0;
// Evaluate the first polynomial (0 <= x <= Pi/4) in y1,
// and the second polynomial (Pi/4 < x <= Pi/2) in y2
auto z = x * x;
auto y1 =
fma(z, Simd<float, N>(2.443315711809948e-5f), -1.388731625493765e-3f);
auto y2 = fma(z, Simd<float, N>(-1.9515295891e-4f), 8.3321608736e-3f);
y1 = fma(y1, z, 4.166664568298827e-2f);
y2 = fma(y2, z, -1.6666654611e-1f);
y1 = y1 * z;
y2 = y2 * z;
y1 = y1 * z;
y2 = fma(x, y2, x);
y1 = fma(z, Simd<float, N>(-0.5f), y1);
y1 = y1 + 1.0f;
if constexpr (Sine) {
auto ys = select(poly_mask, y1, y2);
return select(sign_mask_sin, -ys, ys);
} else {
auto yc = select(poly_mask, y2, y1);
return select(sign_mask_cos, yc, -yc);
}
}
template <typename T, int N>
Simd<T, N> sin(Simd<T, N> x) {
if constexpr (is_complex<T>) {
return std::sin(x.value);
} else {
return sincos<true>(x);
}
}
template <typename T, int N>
Simd<T, N> cos(Simd<T, N> x) {
if constexpr (is_complex<T>) {
return std::cos(x.value);
} else {
return sincos<false>(x);
}
}
template <typename T, int N>
Simd<T, N> erf(Simd<T, N> x) {
// https://github.com/pytorch/pytorch/blob/abf28982a8cb43342e7669d859de9543fd804cc9/aten/src/ATen/cpu/vec/vec256/vec256_float.h#L175
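// This is the Abramowitz & Stegun 7.1.26 rational approximation:
// erf(|x|) ~= 1 - (a1*t + ... + a5*t^5) * exp(-x^2) with t = 1/(1 + p*|x|),
// and the sign of x is restored by the final select.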
Simd<float, N> v = x;
auto t = recip(fma(Simd<float, N>(0.3275911f), abs(v), 1.0f));
auto r = fma(Simd<float, N>(1.061405429f), t, -1.453152027f);
r = fma(r, t, 1.421413741f);
r = fma(r, t, -0.284496736f);
r = fma(r, t, 0.254829592f);
auto e = -exp(-v * v);
auto result = Simd<T, N>(fma(e * t, r, 1.0f));
return select(x > 0, result, -result);
}
template <typename T, int N>
Simd<T, N> erfinv(Simd<T, N> a_) {
Simd<float, N> a = a_;
auto t = fma(a, 0.0f - a, 1.0f);
t = log(t);
auto lhs = [](auto t) {
Simd<float, N> p;
p = 3.03697567e-10f; // 0x1.4deb44p-32
p = fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
p = fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
p = fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
p = fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
p = fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
p = fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
p = fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
return fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
};
auto rhs = [](auto t) {
Simd<float, N> p;
p = 5.43877832e-9f; // 0x1.75c000p-28
p = fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
p = fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
p = fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
p = fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
p = fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
p = fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
p = fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
p = fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
return fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
};
auto thresh = 6.125f;
// Compute both branches and select if N > 1
if constexpr (N == 1) {
if ((abs(t) > thresh).value) { // maximum ulp error = 2.35793
return a * lhs(t);
} else { // maximum ulp error = 2.35002
return a * rhs(t);
}
} else {
return a * select(abs(t) > thresh, lhs(t), rhs(t));
}
}
} // namespace mlx::core::simd

View File

@@ -1,212 +0,0 @@
#pragma once
#include <arm_neon.h>
#include "mlx/backend/common/simd/base_simd.h"
namespace mlx::core::simd {
constexpr int N = 8;
template <>
struct Simd<float16_t, N> {
static constexpr int size = N;
using scalar_t = float16_t;
Simd<float16_t, N>() {}
template <typename U>
Simd<float16_t, N>(U v) : value(vdupq_n_f16(v)){};
Simd<float16_t, N>(float16x8_t v) : value(v){};
Simd<float16_t, N>(Simd<float, N> other) {
auto f32x4_a = *(float32x4_t*)(&other);
auto f32x4_b = *((float32x4_t*)(&other) + 1);
value = vcvt_high_f16_f32(vcvt_f16_f32(f32x4_a), f32x4_b);
};
Simd<float16_t, N>(Simd<uint16_t, N> other) {
value = vcvtq_f16_u16(*(uint16x8_t*)(&other.value));
};
operator Simd<int16_t, N>() {
auto v = vcvtq_s16_f16(value);
return load<int16_t, N>((int16_t*)&v);
};
operator Simd<float, N>() {
float32x4x2_t v;
v.val[0] = vcvt_f32_f16(*(float16x4_t*)(&value));
v.val[1] = vcvt_high_f32_f16(value);
return load<float, N>((float*)&v);
}
float16_t operator[](int idx) const {
return reinterpret_cast<const float16_t*>(&value)[idx];
}
float16_t& operator[](int idx) {
return reinterpret_cast<float16_t*>(&value)[idx];
}
float16x8_t value;
};
#define DEFINE_NEON_UNARY_OP(name, op) \
inline Simd<float16_t, N> name(Simd<float16_t, N> a) { \
return Simd<float16_t, N>{op(a.value)}; \
}
DEFINE_NEON_UNARY_OP(abs, vabsq_f16)
DEFINE_NEON_UNARY_OP(ceil, vrndpq_f16)
DEFINE_NEON_UNARY_OP(floor, vrndmq_f16)
DEFINE_NEON_UNARY_OP(sqrt, vsqrtq_f16)
DEFINE_NEON_UNARY_OP(rsqrt, vrsqrteq_f16)
DEFINE_NEON_UNARY_OP(recip, vrecpeq_f16)
DEFINE_NEON_UNARY_OP(rint, vrndnq_f16)
#define DEFINE_NEON_BINARY_OP(name, op) \
inline Simd<float16_t, N> name(Simd<float16_t, N> a, Simd<float16_t, N> b) { \
return op(a.value, b.value); \
} \
template <typename T> \
Simd<float16_t, N> name(Simd<float16_t, N> a, T b) { \
return op(a.value, Simd<float16_t, N>(b).value); \
} \
template <typename T> \
Simd<float16_t, N> name(T a, Simd<float16_t, N> b) { \
return op(Simd<float16_t, N>(a).value, b.value); \
}
inline Simd<float16_t, N> operator!(Simd<float16_t, N> v) {
auto out = vceqzq_f16(v.value);
return Simd<uint16_t, N>(*(uint16_t*)&out);
}
inline Simd<float16_t, N> operator-(Simd<float16_t, N> v) {
return vnegq_f16(v.value);
}
DEFINE_NEON_BINARY_OP(maximum, vmaxq_f16)
DEFINE_NEON_BINARY_OP(minimum, vminq_f16)
DEFINE_NEON_BINARY_OP(operator+, vaddq_f16)
DEFINE_NEON_BINARY_OP(operator-, vsubq_f16)
DEFINE_NEON_BINARY_OP(operator*, vmulq_f16)
DEFINE_NEON_BINARY_OP(operator/, vdivq_f16)
#define DEFINE_NEON_COMPARISON(Op, op) \
template <typename T> \
Simd<bool, N> operator Op(Simd<float16_t, N> a, T b) { \
auto out = op(a.value, Simd<float16_t, N>(b).value); \
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
} \
template <typename T> \
Simd<bool, N> operator Op(T a, Simd<float16_t, N> b) { \
auto out = op(Simd<float16_t, N>(a).value, b.value); \
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
} \
inline Simd<bool, N> operator Op( \
Simd<float16_t, N> a, Simd<float16_t, N> b) { \
auto out = op(a.value, b.value); \
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
}
DEFINE_NEON_COMPARISON(==, vceqq_f16)
DEFINE_NEON_COMPARISON(>=, vcgeq_f16)
DEFINE_NEON_COMPARISON(<=, vcleq_f16)
DEFINE_NEON_COMPARISON(>, vcgtq_f16)
DEFINE_NEON_COMPARISON(<, vcltq_f16)
template <typename T>
Simd<bool, N> operator!=(Simd<float16_t, N> a, T b) {
return !(a == b);
}
template <typename T>
Simd<bool, N> operator!=(T a, Simd<float16_t, N> b) {
return !(a == b);
}
inline Simd<bool, N> operator!=(Simd<float16_t, N> a, Simd<float16_t, N> b) {
return !(a == b);
}
inline Simd<float16_t, N> operator||(
Simd<float16_t, N> a,
Simd<float16_t, N> b) {
return Simd<uint16_t, N>((a != 0) || (b != 0));
}
template <typename T>
Simd<float16_t, N> operator||(Simd<float16_t, N> a, T b) {
return Simd<uint16_t, N>((a != 0) || (b != 0));
}
template <typename T>
Simd<float16_t, N> operator||(T a, Simd<float16_t, N> b) {
return Simd<uint16_t, N>((a != 0) || (b != 0));
}
inline Simd<float16_t, N> operator&&(
Simd<float16_t, N> a,
Simd<float16_t, N> b) {
return Simd<uint16_t, N>((a != 0) && (b != 0));
}
template <typename T>
Simd<float16_t, N> operator&&(Simd<float16_t, N> a, T b) {
return Simd<uint16_t, N>((a != 0) && (b != 0));
}
template <typename T>
Simd<float16_t, N> operator&&(T a, Simd<float16_t, N> b) {
return Simd<uint16_t, N>((a != 0) && (b != 0));
}
template <>
inline Simd<bool, N> isnan(Simd<float16_t, N> v) {
return v != v;
}
template <>
inline Simd<float16_t, N>
clamp(Simd<float16_t, N> v, Simd<float16_t, N> min, Simd<float16_t, N> max) {
return minimum(maximum(v, min), max);
}
template <typename T>
Simd<float16_t, N> fma(Simd<float16_t, N> x, Simd<float16_t, N> y, T z) {
return vfmaq_f16(x.value, y.value, Simd<float16_t, N>(z).value);
}
template <typename MaskT>
Simd<float16_t, N>
select(Simd<MaskT, N> mask, Simd<float16_t, N> x, Simd<float16_t, N> y) {
return vbslq_f16(Simd<uint16_t, N>(mask).value, x.value, y.value);
}
// Reductions
inline float16_t max(Simd<float16_t, N> x) {
float16x4_t y;
y = vpmax_f16(vget_low_f16(x.value), vget_high_f16(x.value));
y = vpmax_f16(y, y);
y = vpmax_f16(y, y);
return vget_lane_f16(y, 0);
}
inline float16_t min(Simd<float16_t, N> x) {
float16x4_t y;
y = vpmin_f16(vget_low_f16(x.value), vget_high_f16(x.value));
y = vpmin_f16(y, y);
y = vpmin_f16(y, y);
return vget_lane_f16(y, 0);
}
inline float16_t sum(Simd<float16_t, N> x) {
float16x4_t y;
y = vpadd_f16(vget_low_f16(x.value), vget_high_f16(x.value));
y = vpadd_f16(y, y);
y = vpadd_f16(y, y);
return vget_lane_f16(y, 0);
}
inline float16_t prod(Simd<float16_t, N> x) {
auto hx = vmul_f16(vget_low_f16(x.value), vget_high_f16(x.value));
auto out = hx[0];
out *= hx[1];
out *= hx[2];
out *= hx[3];
return out;
}
} // namespace mlx::core::simd

View File

@@ -1,4 +0,0 @@
#pragma once
#include "mlx/backend/common/simd/math.h"
#include "mlx/backend/common/simd/type.h"

View File

@@ -1,7 +0,0 @@
#pragma once
#include "mlx/backend/common/simd/base_simd.h"
#ifdef MLX_USE_ACCELERATE
#include "mlx/backend/common/simd/accelerate_simd.h"
#endif

View File

@@ -1,173 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/simd/simd.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
using namespace mlx::core::simd;
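// The kernel below computes softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
// over each row of length M in three passes: (1) find the row maximum for
// numerical stability, (2) exponentiate and accumulate the normalizer,
// (3) scale by the inverse of the normalizer.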
template <typename T, typename AccT>
void softmax(const array& in, array& out) {
constexpr bool same_t = std::is_same_v<T, AccT>;
constexpr int N = std::min(max_size<AccT>, max_size<T>);
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
const T* current_in_ptr;
T* current_out_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-std::numeric_limits<float>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
if constexpr (same_t) {
store(current_out_ptr, vexp);
}
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if constexpr (same_t) {
*current_out_ptr = _exp;
}
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (same_t) {
store(
current_out_ptr,
Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
} else {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum) * normalizer;
store(current_out_ptr, Simd<T, N>(vexp));
current_in_ptr += N;
}
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (same_t) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++;
}
current_out_ptr++;
}
}
}
} // namespace
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General);
return x_copy;
}
};
array in = check_input(std::move(inputs[0]));
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
switch (in.dtype()) {
case bool_:
case uint8:
case uint16:
case uint32:
case uint64:
case int8:
case int16:
case int32:
case int64:
throw std::runtime_error(
"Softmax is defined only for floating point types");
break;
case float32:
softmax<float, float>(in, out);
break;
case float16:
if (precise_) {
softmax<float16_t, float>(in, out);
} else {
softmax<float16_t, float16_t>(in, out);
}
break;
case bfloat16:
if (precise_) {
softmax<bfloat16_t, float>(in, out);
} else {
softmax<bfloat16_t, bfloat16_t>(in, out);
}
break;
case complex64:
throw std::invalid_argument(
"[Softmax] Not yet implemented for complex64");
break;
}
}
} // namespace mlx::core

View File

@@ -1,426 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
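// StridedIterator presents one line of an array along the sort axis as a
// random-access iterator so the std:: sorting and selection algorithms below
// can work in place along any axis.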
template <typename T>
struct StridedIterator {
using iterator_category = std::random_access_iterator_tag;
using difference_type = int32_t;
using value_type = T;
using reference = value_type&;
using pointer = value_type*;
// Constructors
StridedIterator() = default;
explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
: ptr_(ptr + offset * stride), stride_(stride) {}
explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
: StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {}
// Accessors
reference operator*() const {
return ptr_[0];
}
reference operator[](difference_type idx) const {
return ptr_[idx * stride_];
}
// Comparisons
bool operator==(const StridedIterator& other) const {
return ptr_ == other.ptr_ && stride_ == other.stride_;
}
bool operator!=(const StridedIterator& other) const {
return ptr_ != other.ptr_;
}
bool operator<(const StridedIterator& other) const {
return ptr_ < other.ptr_;
}
bool operator>(const StridedIterator& other) const {
return ptr_ > other.ptr_;
}
bool operator<=(const StridedIterator& other) const {
return ptr_ <= other.ptr_;
}
bool operator>=(const StridedIterator& other) const {
return ptr_ >= other.ptr_;
}
difference_type operator-(const StridedIterator& other) const {
return (ptr_ - other.ptr_) / stride_;
}
// Moving
StridedIterator& operator++() {
ptr_ += stride_;
return *this;
}
StridedIterator& operator--() {
ptr_ -= stride_;
return *this;
}
StridedIterator& operator+=(difference_type diff) {
ptr_ += diff * stride_;
return *this;
}
StridedIterator& operator-=(difference_type diff) {
ptr_ -= diff * stride_;
return *this;
}
StridedIterator operator+(difference_type diff) {
return StridedIterator(ptr_, stride_, diff);
}
StridedIterator operator-(difference_type diff) {
return StridedIterator(ptr_, stride_, -diff);
}
private:
int64_t stride_;
T* ptr_;
};
template <typename T, typename IdxT = uint32_t>
void sort(const array& in, array& out, int axis) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
size_t n_rows = in_size / in.shape(axis);
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = out.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
auto axis_stride = out.strides()[axis];
auto axis_size = out.shape(axis);
// Perform sorting in place
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc;
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed);
src_it.step();
}
}
template <typename T, typename IdxT = uint32_t>
void argsort(const array& in, array& out, int axis) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto in_remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
auto in_remaining_strides = in.strides();
in_remaining_strides.erase(in_remaining_strides.begin() + axis);
auto out_remaining_shape = out.shape();
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
auto in_stride = in.strides()[axis];
auto out_stride = out.strides()[axis];
auto axis_size = in.shape(axis);
// Perform sorting
ContiguousIterator in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc;
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
in_it.step();
out_it.step();
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
// Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator ed(idx_ptr, out_stride, axis_size);
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
}
template <typename T, typename IdxT = uint32_t>
void partition(const array& in, array& out, int axis, int kth) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
size_t n_rows = in_size / in.shape(axis);
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
auto axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc;
src_it.step();
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::nth_element(st, md, ed);
}
}
template <typename T, typename IdxT = uint32_t>
void argpartition(const array& in, array& out, int axis, int kth) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto in_remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
auto in_remaining_strides = in.strides();
in_remaining_strides.erase(in_remaining_strides.begin() + axis);
auto out_remaining_shape = out.shape();
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
auto in_stride = in.strides()[axis];
auto out_stride = out.strides()[axis];
auto axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition
ContiguousIterator in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc;
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
in_it.step();
out_it.step();
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
// Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator md(idx_ptr, out_stride, kth);
StridedIterator ed(idx_ptr, out_stride, axis_size);
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
}
} // namespace
void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
case bool_:
return argsort<bool>(in, out, axis_);
case uint8:
return argsort<uint8_t>(in, out, axis_);
case uint16:
return argsort<uint16_t>(in, out, axis_);
case uint32:
return argsort<uint32_t>(in, out, axis_);
case uint64:
return argsort<uint64_t>(in, out, axis_);
case int8:
return argsort<int8_t>(in, out, axis_);
case int16:
return argsort<int16_t>(in, out, axis_);
case int32:
return argsort<int32_t>(in, out, axis_);
case int64:
return argsort<int64_t>(in, out, axis_);
case float32:
return argsort<float>(in, out, axis_);
case float16:
return argsort<float16_t>(in, out, axis_);
case bfloat16:
return argsort<bfloat16_t>(in, out, axis_);
case complex64:
return argsort<complex64_t>(in, out, axis_);
}
}
void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
case bool_:
return sort<bool>(in, out, axis_);
case uint8:
return sort<uint8_t>(in, out, axis_);
case uint16:
return sort<uint16_t>(in, out, axis_);
case uint32:
return sort<uint32_t>(in, out, axis_);
case uint64:
return sort<uint64_t>(in, out, axis_);
case int8:
return sort<int8_t>(in, out, axis_);
case int16:
return sort<int16_t>(in, out, axis_);
case int32:
return sort<int32_t>(in, out, axis_);
case int64:
return sort<int64_t>(in, out, axis_);
case float32:
return sort<float>(in, out, axis_);
case float16:
return sort<float16_t>(in, out, axis_);
case bfloat16:
return sort<bfloat16_t>(in, out, axis_);
case complex64:
return sort<complex64_t>(in, out, axis_);
}
}
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
case bool_:
return argpartition<bool>(in, out, axis_, kth_);
case uint8:
return argpartition<uint8_t>(in, out, axis_, kth_);
case uint16:
return argpartition<uint16_t>(in, out, axis_, kth_);
case uint32:
return argpartition<uint32_t>(in, out, axis_, kth_);
case uint64:
return argpartition<uint64_t>(in, out, axis_, kth_);
case int8:
return argpartition<int8_t>(in, out, axis_, kth_);
case int16:
return argpartition<int16_t>(in, out, axis_, kth_);
case int32:
return argpartition<int32_t>(in, out, axis_, kth_);
case int64:
return argpartition<int64_t>(in, out, axis_, kth_);
case float32:
return argpartition<float>(in, out, axis_, kth_);
case float16:
return argpartition<float16_t>(in, out, axis_, kth_);
case bfloat16:
return argpartition<bfloat16_t>(in, out, axis_, kth_);
case complex64:
return argpartition<complex64_t>(in, out, axis_, kth_);
}
}
void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
case bool_:
return partition<bool>(in, out, axis_, kth_);
case uint8:
return partition<uint8_t>(in, out, axis_, kth_);
case uint16:
return partition<uint16_t>(in, out, axis_, kth_);
case uint32:
return partition<uint32_t>(in, out, axis_, kth_);
case uint64:
return partition<uint64_t>(in, out, axis_, kth_);
case int8:
return partition<int8_t>(in, out, axis_, kth_);
case int16:
return partition<int16_t>(in, out, axis_, kth_);
case int32:
return partition<int32_t>(in, out, axis_, kth_);
case int64:
return partition<int64_t>(in, out, axis_, kth_);
case float32:
return partition<float>(in, out, axis_, kth_);
case float16:
return partition<float16_t>(in, out, axis_, kth_);
case bfloat16:
return partition<bfloat16_t>(in, out, axis_, kth_);
case complex64:
return partition<complex64_t>(in, out, axis_, kth_);
}
}
} // namespace mlx::core

View File

@@ -1,149 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {
void svd_impl(const array& a, array& u, array& s, array& vt) {
// Lapack uses the column-major convention. To avoid having to transpose
// the input and then transpose the outputs, we swap the indices/sizes of the
// matrices and take advantage of the following identity (see
// https://math.stackexchange.com/a/30077)
// A = UΣVᵀ
// Aᵀ = VΣUᵀ
// As a result some of the indices/sizes are swapped as noted above.
// Rows and cols of the original matrix in row-major order.
const int M = a.shape(-2);
const int N = a.shape(-1);
const int K = std::min(M, N);
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
size_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy.
array in(a.shape(), float32, nullptr, {});
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
// Allocate outputs.
u.set_data(allocator::malloc_or_wait(u.nbytes()));
s.set_data(allocator::malloc_or_wait(s.nbytes()));
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
static constexpr auto job_u = "V";
static constexpr auto job_vt = "V";
static constexpr auto range = "A";
// Will contain the number of singular values after the call has returned.
int ns = 0;
float workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not used
// here but required by lapack).
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
static const int lwork_query = -1;
static const int ignored_int = 0;
static const float ignored_float = 0;
int info;
// Compute workspace size.
MLX_LAPACK_FUNC(sgesvdx)
(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
MLX_LAPACK_FUNC(sgesvdx)
(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in.data<float>() + M * N * i,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s.data<float>() + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt.data<float>() + N * N * i,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u.data<float>() + M * M * i,
/* ldvt = */ &ldvt,
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
}
}
void SVD::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[SVD::eval] only supports float32.");
}
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
}
} // namespace mlx::core

View File

@@ -7,8 +7,6 @@
namespace mlx::core {
namespace {
// TODO: Add support for more combinations of input types.
enum class TernaryOpType {
ScalarScalarScalar,
@@ -16,7 +14,7 @@ enum class TernaryOpType {
General,
};
TernaryOpType
inline TernaryOpType
get_ternary_op_type(const array& a, const array& b, const array& c) {
TernaryOpType topt;
if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
@@ -33,7 +31,7 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
return topt;
}
void set_ternary_op_output_data(
inline void set_ternary_op_output_data(
const array& a,
const array& b,
const array& c,
@@ -76,152 +74,5 @@ void set_ternary_op_output_data(
break;
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op, int D>
void ternary_op_dims(
const T1* a,
const T2* b,
const T3* c,
U* out,
Op op,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& c_strides,
const Strides& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
auto stride_c = c_strides[axis];
auto stride_out = out_strides[axis];
auto N = shape[axis];
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
a,
b,
c,
out,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
axis + 1);
} else {
*out = op(*a, *b, *c);
}
a += stride_a;
b += stride_b;
c += stride_c;
out += stride_out;
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dispatch_dims(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& c_strides = strides[2];
const auto& out_strides = strides[3];
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
int ndim = shape.size();
switch (ndim) {
case 1:
ternary_op_dims<T1, T2, T3, U, Op, 1>(
a_ptr,
b_ptr,
c_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
0);
return;
case 2:
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr,
b_ptr,
c_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
0);
return;
}
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
ContiguousIterator c_it(shape, c_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
c_ptr + c_it.loc,
out_ptr + elem,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
ndim - 2);
a_it.step();
b_it.step();
c_it.step();
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
// The full computation is scalar-scalar-scalar so we call the base op once.
if (topt == TernaryOpType::ScalarScalarScalar) {
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
} else if (topt == TernaryOpType::VectorVectorVector) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
for (size_t i = 0; i < out.size(); ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
b_ptr++;
c_ptr++;
out_ptr++;
}
} else {
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
}
}
} // namespace
} // namespace mlx::core

View File

@@ -1,31 +0,0 @@
// Copyright © 2023 Apple Inc.
#include "mlx/backend/common/threefry.h"
namespace mlx::core::random {
std::pair<uint32_t, uint32_t> threefry2x32_hash(
const std::pair<uint32_t, uint32_t>& key,
std::pair<uint32_t, uint32_t> count) {
constexpr static uint32_t rotations[2][4] = {
{13, 15, 26, 6}, {17, 29, 16, 24}};
uint32_t ks[3] = {key.first, key.second, key.first ^ key.second ^ 0x1BD11BDA};
count.first += ks[0];
count.second += ks[1];
for (int i = 0; i < 5; ++i) {
for (auto r : rotations[i % 2]) {
count.first += count.second;
count.second = (count.second << r) | (count.second >> (32 - r));
count.second ^= count.first;
}
count.first += ks[(i + 1) % 3];
count.second += ks[(i + 2) % 3] + i + 1;
}
return count;
}
} // namespace mlx::core::random

View File

@@ -1,21 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <cstdint>
#include <utility>
namespace mlx::core::random {
/** Applies the Threefry 2x32 hash function.
* This code is based on the Jax counter-based and splittable PRNG
* https://github.com/google/jax/blob/main/docs/jep/263-prng.md
*
* Original Threefry reference:
* http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
*/
std::pair<uint32_t, uint32_t> threefry2x32_hash(
const std::pair<uint32_t, uint32_t>& key,
std::pair<uint32_t, uint32_t> count);
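// Illustrative usage sketch (hypothetical helper, not part of the original
// header): hash a fixed key with a per-call counter to draw two pseudo-random
// 32-bit words.
inline std::pair<uint32_t, uint32_t> example_threefry_draw(
uint32_t counter_hi,
uint32_t counter_lo) {
std::pair<uint32_t, uint32_t> key{0u, 42u}; // arbitrary fixed seed
return threefry2x32_hash(key, {counter_hi, counter_lo});
}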
} // namespace mlx::core::random

View File

@@ -1,285 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/unary.h"
#include "mlx/backend/common/unary_ops.h"
#include "mlx/primitives.h"
namespace mlx::core {
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
if (issubdtype(in.dtype(), unsignedinteger) || in.dtype() == bool_) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
auto op = detail::Abs{};
switch (out.dtype()) {
case int8:
unary_op<int8_t>(in, out, op);
break;
case int16:
unary_op<int16_t>(in, out, op);
break;
case int32:
unary_op<int32_t>(in, out, op);
break;
case int64:
unary_op<int64_t>(in, out, op);
break;
case float16:
unary_op<float16_t>(in, out, op);
break;
case float32:
unary_op<float>(in, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, op);
break;
case complex64:
unary_op<complex64_t>(in, out, op);
break;
default:
throw std::runtime_error("[Abs] Called on unsigned type");
}
}
}
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCos());
}
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCosh());
}
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSin());
}
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSinh());
}
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTan());
}
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTanh());
}
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Ceil());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
unary_op<complex64_t>(inputs[0], out, detail::Conjugate());
}
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Cos());
}
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Cosh());
}
void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::Erf());
break;
case float16:
unary_op<float16_t>(in, out, detail::Erf());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::Erf());
break;
default:
throw std::invalid_argument(
"[erf] Error function only defined for arrays"
" with real floating point type.");
}
}
void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::ErfInv());
break;
case float16:
unary_op<float16_t>(in, out, detail::ErfInv());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::ErfInv());
break;
default:
throw std::invalid_argument(
"[erf_inv] Inverse error function only defined for arrays"
" with real floating point type.");
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Exp());
}
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Expm1());
}
void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Floor());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
}
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (base_) {
case Base::e:
unary_fp(in, out, detail::Log());
break;
case Base::two:
unary_fp(in, out, detail::Log2());
break;
case Base::ten:
unary_fp(in, out, detail::Log10());
break;
}
}
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Log1p());
}
void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::LogicalNot());
}
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Negative());
}
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Round());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sigmoid());
}
void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == bool_) {
out.copy_shared_buffer(in);
} else {
unary(in, out, detail::Sign());
}
}
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sin());
}
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sinh());
}
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Square());
}
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (recip_) {
unary_fp(in, out, detail::Rsqrt());
} else {
unary_fp(in, out, detail::Sqrt());
}
}
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Tan());
}
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Tanh());
}
} // namespace mlx::core

View File

@@ -1,142 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/simd/simd.h"
#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
void set_unary_output_data(const array& in, array& out) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
}
}
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) {
out[i] = op(*a);
a += stride;
}
}
template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>();
if (a.flags().contiguous) {
set_unary_output_data(a, out);
U* dst = out.data<U>();
constexpr int N = simd::max_size<T>;
size_t size = a.data_size();
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a_ptr)));
size -= N;
a_ptr += N;
dst += N;
}
while (size > 0) {
*dst = op(*a_ptr);
size--;
dst++;
a_ptr++;
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
U* dst = out.data<U>();
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
if (a.ndim() <= 1) {
unary_op(a_ptr, dst, op, shape, stride);
return;
}
ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1);
for (size_t elem = 0; elem < a.size(); elem += shape) {
unary_op(a_ptr + it.loc, dst + elem, op, shape, stride);
it.step();
}
}
}
template <typename Op>
void unary(const array& a, array& out, Op op) {
switch (out.dtype()) {
case bool_:
unary_op<bool>(a, out, op);
break;
case uint8:
unary_op<uint8_t>(a, out, op);
break;
case uint16:
unary_op<uint16_t>(a, out, op);
break;
case uint32:
unary_op<uint32_t>(a, out, op);
break;
case uint64:
unary_op<uint64_t>(a, out, op);
break;
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
}
}
template <typename Op>
void unary_fp(const array& a, array& out, Op op) {
switch (out.dtype()) {
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_fp] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
}
} // namespace
} // namespace mlx::core

View File

@@ -1,108 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <stdint.h>
#include <cmath>
#include <complex>
#include "mlx/backend/common/simd/simd.h"
namespace mlx::core::detail {
using namespace mlx::core::simd;
#define SINGLE() \
template <typename T> \
T operator()(T x) { \
return (*this)(Simd<T, 1>(x)).value; \
}
#define DEFAULT_OP(Op, op) \
struct Op { \
template <int N, typename T> \
Simd<T, N> operator()(Simd<T, N> x) { \
return simd::op(x); \
} \
SINGLE() \
};
DEFAULT_OP(Abs, abs)
DEFAULT_OP(ArcCos, acos)
DEFAULT_OP(ArcCosh, acosh)
DEFAULT_OP(ArcSin, asin)
DEFAULT_OP(ArcSinh, asinh)
DEFAULT_OP(ArcTan, atan)
DEFAULT_OP(ArcTanh, atanh)
DEFAULT_OP(Ceil, ceil)
DEFAULT_OP(Conjugate, conj)
DEFAULT_OP(Cos, cos)
DEFAULT_OP(Cosh, cosh)
DEFAULT_OP(Erf, erf)
DEFAULT_OP(ErfInv, erfinv)
DEFAULT_OP(Exp, exp)
DEFAULT_OP(Expm1, expm1)
DEFAULT_OP(Floor, floor);
DEFAULT_OP(Log, log);
DEFAULT_OP(Log2, log2);
DEFAULT_OP(Log10, log10);
DEFAULT_OP(Log1p, log1p);
DEFAULT_OP(LogicalNot, operator!)
DEFAULT_OP(Negative, operator-)
DEFAULT_OP(Round, rint);
DEFAULT_OP(Sin, sin)
DEFAULT_OP(Sinh, sinh)
DEFAULT_OP(Sqrt, sqrt)
DEFAULT_OP(Rsqrt, rsqrt)
DEFAULT_OP(Tan, tan)
DEFAULT_OP(Tanh, tanh)
struct Imag {
template <int N>
Simd<float, N> operator()(Simd<complex64_t, N> x) {
return simd::imag(x);
}
SINGLE()
};
struct Real {
template <int N>
Simd<float, N> operator()(Simd<complex64_t, N> x) {
return simd::real(x);
}
SINGLE()
};
struct Sigmoid {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
return 1.0f / (1.0f + simd::exp(-x));
}
SINGLE()
};
struct Sign {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
auto z = Simd<T, N>{0};
if constexpr (std::is_unsigned_v<T>) {
return x != z;
} else if constexpr (std::is_same_v<T, complex64_t>) {
return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
} else {
return simd::select(
x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
}
}
SINGLE()
};
struct Square {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
return x * x;
}
SINGLE()
};
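// The functors above accept both plain scalars (via SINGLE()) and SIMD
// vectors, e.g. detail::Square{}(3.0f) == 9.0f while
// detail::Square{}(Simd<float, max_size<float>>(2.0f)) yields a vector of
// 4.0f; the unary kernels rely on this to share one op for the vectorized
// body and the scalar tail.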
} // namespace mlx::core::detail