mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Start to cleanup/unify accelerate and common back-ends (Part 1/N) (#1777)
* start to cleanup/unify accelerate and common back-ends * more progress * simplify * add half type and allow infs in simd exp * unify softmax + quantized, more dispatches to simd quantized mm * add sin/cos, use simd in vector-scalar ops * faster CPU vectorize quant * faster erf/erfinv
This commit is contained in:
@@ -5,6 +5,18 @@ else()
|
||||
set(COMPILER ${CMAKE_CXX_COMPILER})
|
||||
endif()
|
||||
|
||||
set(COMPILE_DEPS
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
|
||||
simd/simd.h
|
||||
simd/base_simd.h
|
||||
simd/math.h
|
||||
simd/type.h
|
||||
unary_ops.h
|
||||
binary_ops.h)
|
||||
|
||||
if(MSVC)
|
||||
set(SHELL_EXT ps1)
|
||||
set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
|
||||
@@ -19,13 +31,8 @@ add_custom_command(
|
||||
${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
|
||||
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
|
||||
DEPENDS make_compiled_preamble.${SHELL_EXT}
|
||||
compiled_preamble.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
|
||||
ops.h)
|
||||
DEPENDS make_compiled_preamble.${SHELL_EXT} compiled_preamble.h
|
||||
${COMPILE_DEPS})
|
||||
|
||||
add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
|
||||
|
||||
@@ -60,6 +67,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ void arg_reduce_dispatch(
|
||||
|
||||
} // namespace
|
||||
|
||||
void ArgReduce::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/binary_ops.h"
|
||||
#include "mlx/backend/common/binary_two.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@@ -15,69 +15,61 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void comparison_op(const array& a, const array& b, array& out, Op op) {
|
||||
DefaultScalarVector<T, U, Op> opsv(op);
|
||||
DefaultVectorScalar<T, U, Op> opvs(op);
|
||||
DefaultVectorVector<T, U, Op> opvv(op);
|
||||
binary_op<T, U>(a, b, out, op, opsv, opvs, opvv);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void comparison_op(const array& a, const array& b, array& out, Op op) {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
comparison_op<bool, bool>(a, b, out, op);
|
||||
binary_op<bool, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint8:
|
||||
comparison_op<uint8_t, bool>(a, b, out, op);
|
||||
binary_op<uint8_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
comparison_op<uint16_t, bool>(a, b, out, op);
|
||||
binary_op<uint16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
comparison_op<uint32_t, bool>(a, b, out, op);
|
||||
binary_op<uint32_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
comparison_op<uint64_t, bool>(a, b, out, op);
|
||||
binary_op<uint64_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int8:
|
||||
comparison_op<int8_t, bool>(a, b, out, op);
|
||||
binary_op<int8_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int16:
|
||||
comparison_op<int16_t, bool>(a, b, out, op);
|
||||
binary_op<int16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int32:
|
||||
comparison_op<int32_t, bool>(a, b, out, op);
|
||||
binary_op<int32_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int64:
|
||||
comparison_op<int64_t, bool>(a, b, out, op);
|
||||
binary_op<int64_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case float16:
|
||||
comparison_op<float16_t, bool>(a, b, out, op);
|
||||
binary_op<float16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case float32:
|
||||
comparison_op<float, bool>(a, b, out, op);
|
||||
binary_op<float, bool>(a, b, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
comparison_op<bfloat16_t, bool>(a, b, out, op);
|
||||
binary_op<bfloat16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
comparison_op<complex64_t, bool>(a, b, out, op);
|
||||
binary_op<complex64_t, bool>(a, b, out, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Add::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Add());
|
||||
}
|
||||
|
||||
void DivMod::eval(
|
||||
void DivMod::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
@@ -132,50 +124,68 @@ void DivMod::eval(
|
||||
}
|
||||
}
|
||||
|
||||
void Divide::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Divide());
|
||||
}
|
||||
|
||||
void Remainder::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Remainder());
|
||||
}
|
||||
|
||||
void Equal::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (equal_nan_) {
|
||||
comparison_op(inputs[0], inputs[1], out, detail::NaNEqual());
|
||||
switch (a.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, bool>(a, b, out, detail::NaNEqual());
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool>(a, b, out, detail::NaNEqual());
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool>(a, b, out, detail::NaNEqual());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[NanEqual::eval_cpu] Only for floating point types.");
|
||||
}
|
||||
} else {
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Equal());
|
||||
comparison_op(a, b, out, detail::Equal());
|
||||
}
|
||||
}
|
||||
|
||||
void Greater::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Greater());
|
||||
}
|
||||
|
||||
void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
|
||||
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
|
||||
}
|
||||
|
||||
void Less::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Less());
|
||||
}
|
||||
|
||||
void LessEqual::eval(const std::vector<array>& inputs, array& out) {
|
||||
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
|
||||
}
|
||||
|
||||
void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
@@ -196,54 +206,54 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
|
||||
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalAnd());
|
||||
}
|
||||
|
||||
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
|
||||
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalOr());
|
||||
}
|
||||
|
||||
void Maximum::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Maximum());
|
||||
}
|
||||
|
||||
void Minimum::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Minimum());
|
||||
}
|
||||
|
||||
void Multiply::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Multiply());
|
||||
}
|
||||
|
||||
void NotEqual::eval(const std::vector<array>& inputs, array& out) {
|
||||
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
|
||||
}
|
||||
|
||||
void Power::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Power());
|
||||
}
|
||||
|
||||
void Subtract::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
@@ -307,7 +317,7 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
@@ -122,16 +124,22 @@ void set_binary_op_output_data(
|
||||
}
|
||||
}
|
||||
|
||||
struct UseDefaultBinaryOp {};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultVectorScalar {
|
||||
template <typename Op>
|
||||
struct VectorScalar {
|
||||
Op op;
|
||||
|
||||
DefaultVectorScalar(Op op_) : op(op_) {}
|
||||
VectorScalar(Op op_) : op(op_) {}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
T scalar = *b;
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
|
||||
dst += N;
|
||||
a += N;
|
||||
size -= N;
|
||||
}
|
||||
while (size-- > 0) {
|
||||
*dst = op(*a, scalar);
|
||||
dst++;
|
||||
@@ -140,14 +148,22 @@ struct DefaultVectorScalar {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultScalarVector {
|
||||
template <typename Op>
|
||||
struct ScalarVector {
|
||||
Op op;
|
||||
|
||||
DefaultScalarVector(Op op_) : op(op_) {}
|
||||
ScalarVector(Op op_) : op(op_) {}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
T scalar = *a;
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
|
||||
dst += N;
|
||||
b += N;
|
||||
size -= N;
|
||||
}
|
||||
while (size-- > 0) {
|
||||
*dst = op(scalar, *b);
|
||||
dst++;
|
||||
@@ -156,13 +172,22 @@ struct DefaultScalarVector {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultVectorVector {
|
||||
template <typename Op>
|
||||
struct VectorVector {
|
||||
Op op;
|
||||
|
||||
DefaultVectorVector(Op op_) : op(op_) {}
|
||||
VectorVector(Op op_) : op(op_) {}
|
||||
|
||||
template <typename T, typename U>
|
||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||
constexpr int N = simd::max_size<T>;
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::load<T, N>(a), simd::load<T, N>(b)));
|
||||
dst += N;
|
||||
a += N;
|
||||
b += N;
|
||||
size -= N;
|
||||
}
|
||||
while (size-- > 0) {
|
||||
*dst = op(*a, *b);
|
||||
dst++;
|
||||
@@ -277,21 +302,8 @@ void binary_op_dispatch_dims(
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename OpSV,
|
||||
typename OpVS,
|
||||
typename OpVV>
|
||||
void binary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
OpSV opsv,
|
||||
OpVS opvs,
|
||||
OpVV opvv) {
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
@@ -303,19 +315,19 @@ void binary_op(
|
||||
|
||||
// The full computation is scalar vector so delegate to the op
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
|
||||
ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector scalar so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorScalar) {
|
||||
opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
|
||||
VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector vector so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorVector) {
|
||||
opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
|
||||
VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -376,15 +388,39 @@ void binary_op(
|
||||
switch (bopt) {
|
||||
case BinaryOpType::VectorVector:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
VectorVector{op},
|
||||
dim,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
VectorScalar{op},
|
||||
dim,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
ScalarVector{op},
|
||||
dim,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
default:
|
||||
binary_op_dispatch_dims<T, U, false>(
|
||||
@@ -393,134 +429,52 @@ void binary_op(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
|
||||
void binary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
OpSV opsv,
|
||||
OpVS opvs,
|
||||
OpVV opvv) {
|
||||
// TODO: The following mess of constexpr evaluations can probably be achieved
|
||||
// with template specializations and overloading. Would it be simpler?
|
||||
|
||||
if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
|
||||
if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||
if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// opsv and opvs were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
opvv);
|
||||
}
|
||||
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
|
||||
value) {
|
||||
// opsv and opvv were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
opvs,
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// opsv was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
|
||||
}
|
||||
} else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
|
||||
value) {
|
||||
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// opvs and opvv were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
op,
|
||||
opsv,
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// opvs was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
|
||||
}
|
||||
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
|
||||
value) {
|
||||
// opvv was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// All ops provided
|
||||
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||
DefaultScalarVector<T, T, Op> opsv(op);
|
||||
DefaultVectorScalar<T, T, Op> opvs(op);
|
||||
DefaultVectorVector<T, T, Op> opvv(op);
|
||||
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
|
||||
binary_op<T, T>(a, b, out, op);
|
||||
}
|
||||
|
||||
template <typename... Ops>
|
||||
void binary(const array& a, const array& b, array& out, Ops... ops) {
|
||||
template <typename Op>
|
||||
void binary(const array& a, const array& b, array& out, Op op) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, out, ops...);
|
||||
binary_op<bool>(a, b, out, op);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, out, ops...);
|
||||
binary_op<uint8_t>(a, b, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, out, ops...);
|
||||
binary_op<uint16_t>(a, b, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, out, ops...);
|
||||
binary_op<uint32_t>(a, b, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, out, ops...);
|
||||
binary_op<uint64_t>(a, b, out, op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, out, ops...);
|
||||
binary_op<int8_t>(a, b, out, op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, out, ops...);
|
||||
binary_op<int16_t>(a, b, out, op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, out, ops...);
|
||||
binary_op<int32_t>(a, b, out, op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, out, ops...);
|
||||
binary_op<int64_t>(a, b, out, op);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, out, ops...);
|
||||
binary_op<float16_t>(a, b, out, op);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, out, ops...);
|
||||
binary_op<float>(a, b, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, out, ops...);
|
||||
binary_op<bfloat16_t>(a, b, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t>(a, b, out, ops...);
|
||||
binary_op<complex64_t>(a, b, out, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
98
mlx/backend/common/binary_ops.h
Normal file
98
mlx/backend/common/binary_ops.h
Normal file
@@ -0,0 +1,98 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
|
||||
namespace mlx::core::detail {
|
||||
|
||||
using namespace mlx::core::simd;
|
||||
|
||||
#define BINARY_SINGLE() \
|
||||
template <typename T> \
|
||||
T operator()(T x, T y) { \
|
||||
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value; \
|
||||
}
|
||||
|
||||
#define DEFAULT_BINARY_OP(Op, op) \
|
||||
struct Op { \
|
||||
template <int N, typename T> \
|
||||
Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) { \
|
||||
return op(x, y); \
|
||||
} \
|
||||
BINARY_SINGLE() \
|
||||
};
|
||||
|
||||
DEFAULT_BINARY_OP(Add, operator+)
|
||||
DEFAULT_BINARY_OP(ArcTan2, atan2)
|
||||
DEFAULT_BINARY_OP(Divide, operator/)
|
||||
DEFAULT_BINARY_OP(Multiply, operator*)
|
||||
DEFAULT_BINARY_OP(Subtract, operator-)
|
||||
DEFAULT_BINARY_OP(LogicalAnd, operator&&)
|
||||
DEFAULT_BINARY_OP(LogicalOr, operator||)
|
||||
DEFAULT_BINARY_OP(BitwiseAnd, operator&)
|
||||
DEFAULT_BINARY_OP(BitwiseOr, operator|)
|
||||
DEFAULT_BINARY_OP(BitwiseXor, operator^)
|
||||
DEFAULT_BINARY_OP(LeftShift, operator<<)
|
||||
DEFAULT_BINARY_OP(RightShift, operator>>)
|
||||
DEFAULT_BINARY_OP(Remainder, remainder)
|
||||
DEFAULT_BINARY_OP(Maximum, maximum)
|
||||
DEFAULT_BINARY_OP(Minimum, minimum)
|
||||
DEFAULT_BINARY_OP(Power, pow)
|
||||
|
||||
#define DEFAULT_BOOL_OP(Op, op) \
|
||||
struct Op { \
|
||||
template <int N, typename T> \
|
||||
Simd<bool, N> operator()(Simd<T, N> x, Simd<T, N> y) { \
|
||||
return op(x, y); \
|
||||
} \
|
||||
template <typename T> \
|
||||
bool operator()(T x, T y) { \
|
||||
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value; \
|
||||
} \
|
||||
};
|
||||
|
||||
DEFAULT_BOOL_OP(Equal, operator==)
|
||||
DEFAULT_BOOL_OP(Greater, operator>)
|
||||
DEFAULT_BOOL_OP(GreaterEqual, operator>=)
|
||||
DEFAULT_BOOL_OP(Less, operator<)
|
||||
DEFAULT_BOOL_OP(LessEqual, operator<=)
|
||||
DEFAULT_BOOL_OP(NotEqual, operator!=)
|
||||
|
||||
struct NaNEqual {
|
||||
template <int N, typename T>
|
||||
Simd<bool, N> operator()(Simd<T, N> x, Simd<T, N> y) {
|
||||
return x == y || (isnan(x) && isnan(y));
|
||||
}
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <int N, typename T>
|
||||
Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) {
|
||||
auto maxval = maximum(x, y);
|
||||
auto minval = minimum(x, y);
|
||||
auto mask = minval == -inf || maxval == inf;
|
||||
auto out = maxval + log1p(exp(minval - maxval));
|
||||
return select(mask, Simd<T, N>(maxval), Simd<T, N>(out));
|
||||
}
|
||||
BINARY_SINGLE()
|
||||
};
|
||||
|
||||
struct Select {
|
||||
template <typename T>
|
||||
T operator()(bool condition, T x, T y) {
|
||||
return (*this)(Simd<bool, 1>(condition), Simd<T, 1>(x), Simd<T, 1>(y))
|
||||
.value;
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
Simd<T, N> operator()(Simd<bool, N> condition, Simd<T, N> x, Simd<T, N> y) {
|
||||
return select(condition, x, y);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::detail
|
||||
@@ -64,7 +64,7 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
|
||||
}
|
||||
}
|
||||
|
||||
void Cholesky::eval(const std::vector<array>& inputs, array& output) {
|
||||
void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
|
||||
if (inputs[0].dtype() != float32) {
|
||||
throw std::runtime_error("[Cholesky::eval] only supports float32.");
|
||||
}
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
// clang-format off
|
||||
#include "mlx/types/half_types.h"
|
||||
#include "mlx/types/complex.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/backend/common/unary_ops.h"
|
||||
#include "mlx/backend/common/binary_ops.h"
|
||||
// clang-format on
|
||||
|
||||
const char* get_kernel_preamble();
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -23,6 +24,7 @@ template <typename SrcT, typename DstT>
|
||||
void copy_vector(const array& src, array& dst) {
|
||||
auto src_ptr = src.data<SrcT>();
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
size_t size = src.data_size();
|
||||
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
|
||||
}
|
||||
|
||||
|
||||
@@ -21,98 +21,9 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
DEFAULT(Abs)
|
||||
DEFAULT(Add)
|
||||
DEFAULT(Arange)
|
||||
DEFAULT(ArcCos)
|
||||
DEFAULT(ArcCosh)
|
||||
DEFAULT(ArcSin)
|
||||
DEFAULT(ArcSinh)
|
||||
DEFAULT(ArcTan)
|
||||
DEFAULT(ArcTan2)
|
||||
DEFAULT(ArcTanh)
|
||||
DEFAULT(ArgPartition)
|
||||
DEFAULT(ArgReduce)
|
||||
DEFAULT(ArgSort)
|
||||
DEFAULT(AsType)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(BroadcastAxes)
|
||||
DEFAULT(BlockMaskedMM)
|
||||
DEFAULT(GatherMM)
|
||||
DEFAULT(GatherQMM)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Conjugate)
|
||||
DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT(Cos)
|
||||
DEFAULT(Cosh)
|
||||
DEFAULT_MULTI(CustomTransforms)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT(Divide)
|
||||
DEFAULT(NumberOfElements)
|
||||
DEFAULT(Remainder)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(Exp)
|
||||
DEFAULT(ExpandDims)
|
||||
DEFAULT(Expm1)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Full)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Hadamard)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
DEFAULT(Log)
|
||||
DEFAULT(Log1p)
|
||||
DEFAULT(LogicalNot)
|
||||
DEFAULT(LogicalAnd)
|
||||
DEFAULT(LogicalOr)
|
||||
DEFAULT(LogAddExp)
|
||||
DEFAULT(Maximum)
|
||||
DEFAULT(Minimum)
|
||||
DEFAULT(Multiply)
|
||||
DEFAULT(Negative)
|
||||
DEFAULT(NotEqual)
|
||||
DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT(Power)
|
||||
DEFAULT_MULTI(QRF)
|
||||
DEFAULT(QuantizedMatmul)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reduce)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scan)
|
||||
DEFAULT(Scatter)
|
||||
DEFAULT(Select)
|
||||
DEFAULT(Sigmoid)
|
||||
DEFAULT(Sign)
|
||||
DEFAULT(Sin)
|
||||
DEFAULT(Sinh)
|
||||
DEFAULT(Slice)
|
||||
DEFAULT(SliceUpdate)
|
||||
DEFAULT(Softmax)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Square)
|
||||
DEFAULT(Squeeze)
|
||||
DEFAULT(Sqrt)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT(Subtract)
|
||||
DEFAULT_MULTI(SVD)
|
||||
DEFAULT(Tan)
|
||||
DEFAULT(Tanh)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT(Inverse)
|
||||
DEFAULT(Cholesky)
|
||||
DEFAULT_MULTI(Eigh)
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@@ -45,7 +45,9 @@ void ssyevd(
|
||||
|
||||
} // namespace
|
||||
|
||||
void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
|
||||
void Eigh::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const auto& a = inputs[0];
|
||||
auto& values = outputs[0];
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void FFT::eval(const std::vector<array>& inputs, array& out) {
|
||||
void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
std::vector<std::ptrdiff_t> strides_in(
|
||||
in.strides().begin(), in.strides().end());
|
||||
|
||||
@@ -82,7 +82,7 @@ void hadamard(array& out, int n, int m, float scale) {
|
||||
}
|
||||
}
|
||||
|
||||
void Hadamard::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
@@ -104,4 +104,4 @@ void Hadamard::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -162,7 +162,7 @@ void dispatch_gather(
|
||||
}
|
||||
}
|
||||
|
||||
void Gather::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& src = inputs[0];
|
||||
@@ -337,7 +337,7 @@ void dispatch_scatter(
|
||||
}
|
||||
}
|
||||
|
||||
void Scatter::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() >= 2);
|
||||
|
||||
auto& src = inputs[0];
|
||||
|
||||
@@ -110,7 +110,7 @@ void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
|
||||
}
|
||||
}
|
||||
|
||||
void Inverse::eval(const std::vector<array>& inputs, array& output) {
|
||||
void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
|
||||
if (inputs[0].dtype() != float32) {
|
||||
throw std::runtime_error("[Inverse::eval] only supports float32.");
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#define lapack_complex_double std::complex<double>
|
||||
#endif
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#ifdef MLX_USE_ACCELERATE
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <utility>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/load.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -51,11 +48,4 @@ void load(
|
||||
}
|
||||
}
|
||||
|
||||
void Load::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
load(out, offset_, reader_, swap_endianness_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -53,7 +53,7 @@ inline void mask_matrix(
|
||||
|
||||
} // namespace
|
||||
|
||||
void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[BlockMaskedMM::eval] Currently only supports float32.");
|
||||
@@ -210,7 +210,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void GatherMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[GatherMM::eval] Currently only supports float32.");
|
||||
|
||||
@@ -1,680 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
#include <stdint.h>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
|
||||
namespace mlx::core::detail {
|
||||
|
||||
namespace {
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
} // namespace
|
||||
|
||||
typedef union {
|
||||
int i;
|
||||
float f;
|
||||
} IntOrFloat;
|
||||
|
||||
inline float fast_exp(float x) {
|
||||
if (x == -std::numeric_limits<float>::infinity()) {
|
||||
return 0.0f;
|
||||
} else if (x == std::numeric_limits<float>::infinity() || std::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
x *= 1.442695; // multiply with log_2(e)
|
||||
float ipart, fpart;
|
||||
IntOrFloat epart;
|
||||
x = std::max(-80.f, std::min(x, 80.f));
|
||||
ipart = std::floor(x + 0.5);
|
||||
fpart = x - ipart;
|
||||
|
||||
x = 1.535336188319500e-4f;
|
||||
x = x * fpart + 1.339887440266574e-3f;
|
||||
x = x * fpart + 9.618437357674640e-3f;
|
||||
x = x * fpart + 5.550332471162809e-2f;
|
||||
x = x * fpart + 2.402264791363012e-1f;
|
||||
x = x * fpart + 6.931472028550421e-1f;
|
||||
x = x * fpart + 1.000000000000000f;
|
||||
|
||||
// generate 2**ipart in the floating point representation using integer
|
||||
// bitshifting
|
||||
epart.i = (int(ipart) + 127) << 23;
|
||||
|
||||
return epart.f * x;
|
||||
}
|
||||
|
||||
inline float fast_erf(float a) {
|
||||
float r, s, t, u;
|
||||
t = std::abs(a);
|
||||
s = a * a;
|
||||
if (t > 0.927734375f) {
|
||||
// maximum error 0.99527 ulp
|
||||
r = std::fma(
|
||||
-1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
|
||||
u = std::fma(
|
||||
-3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
|
||||
r = std::fma(r, s, u);
|
||||
r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
|
||||
r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
|
||||
r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
|
||||
r = std::fma(r, t, -t);
|
||||
// TODO, replace with expm1 when implemented
|
||||
r = 1.0f - std::exp(r);
|
||||
r = std::copysign(r, a);
|
||||
} else {
|
||||
// maximum error 0.98929 ulp
|
||||
r = -5.96761703e-4f; // -0x1.38e000p-11
|
||||
r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
|
||||
r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
|
||||
r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
|
||||
r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
|
||||
r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
|
||||
r = std::fma(r, a, a);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
inline float fast_erfinv(float a) {
|
||||
auto t = std::fma(a, 0.0f - a, 1.0f);
|
||||
t = std::log(t);
|
||||
float p;
|
||||
if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
|
||||
p = 3.03697567e-10f; // 0x1.4deb44p-32
|
||||
p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
|
||||
p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
|
||||
p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
|
||||
p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
|
||||
p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
|
||||
p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
|
||||
p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
|
||||
p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
|
||||
} else { // maximum ulp error = 2.35002
|
||||
p = 5.43877832e-9f; // 0x1.75c000p-28
|
||||
p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
|
||||
p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
|
||||
p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
|
||||
p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
|
||||
p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
|
||||
p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
|
||||
p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
|
||||
p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
|
||||
p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
|
||||
}
|
||||
return a * p;
|
||||
}
|
||||
|
||||
struct Abs {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::abs(x);
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::acos(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::acosh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::asin(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::asinh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::atan(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan2 {
|
||||
template <typename T>
|
||||
T operator()(T y, T x) {
|
||||
return std::atan2(y, x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::atanh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::ceil(x);
|
||||
}
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
}
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
}
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
}
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Conjugate {
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return std::conj(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::cos(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::cosh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(fast_erf(static_cast<float>(x)));
|
||||
}
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(fast_erfinv(static_cast<float>(x)));
|
||||
}
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return fast_exp(x);
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return std::exp(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Expm1 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return expm1(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::floor(x);
|
||||
}
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
}
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
}
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
}
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Imag {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::imag(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log2(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log10(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return log1p(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return !x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return -x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Real {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::real(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::rint(x);
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {std::rint(x.real()), std::rint(x.imag())};
|
||||
}
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
auto one = static_cast<decltype(x)>(1.0);
|
||||
return one / (one + fast_exp(-x));
|
||||
}
|
||||
};
|
||||
|
||||
struct Sign {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return (x > T(0)) - (x < T(0));
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return x == complex64_t(0) ? x : x / std::abs(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sin(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sinh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return x * x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<decltype(x)>(1.0) / std::sqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::tan(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::tanh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Add {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x + y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Divide {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x / y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
|
||||
T numerator,
|
||||
T denominator) {
|
||||
return numerator % denominator;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
|
||||
T numerator,
|
||||
T denominator) {
|
||||
auto r = numerator % denominator;
|
||||
if (r != 0 && (r < 0 != denominator < 0))
|
||||
r += denominator;
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
|
||||
T numerator,
|
||||
T denominator) {
|
||||
auto r = std::fmod(numerator, denominator);
|
||||
if (r != 0 && (r < 0 != denominator < 0)) {
|
||||
r += denominator;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t numerator, complex64_t denominator) {
|
||||
return numerator % denominator;
|
||||
}
|
||||
};
|
||||
|
||||
struct Equal {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x == y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NaNEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
// isnan always returns false for integers, and MSVC refuses to compile.
|
||||
return x == y;
|
||||
} else {
|
||||
return x == y || (std::isnan(x) && std::isnan(y));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Greater {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x > y;
|
||||
}
|
||||
};
|
||||
|
||||
struct GreaterEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x >= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Less {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x < y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LessEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x <= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return (x > y) ? x : y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (std::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return (x > y) ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Minimum {
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (std::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
auto maxval = Maximum()(x, y);
|
||||
auto minval = Minimum()(x, y);
|
||||
return (minval == -inf || maxval == inf)
|
||||
? maxval
|
||||
: static_cast<decltype(x)>(
|
||||
maxval + std::log1p(fast_exp(minval - maxval)));
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x * y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NotEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x != y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Power {
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
return std::pow(base, exp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
struct Subtract {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x - y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x && y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x || y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Select {
|
||||
template <typename T>
|
||||
T operator()(bool condition, T x, T y) {
|
||||
return condition ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x & y;
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x | y;
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseXor {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x ^ y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LeftShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x << y;
|
||||
}
|
||||
};
|
||||
|
||||
struct RightShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x >> y;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::detail
|
||||
@@ -9,10 +9,9 @@
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/arange.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/backend/common/load.h"
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/common/threefry.h"
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
@@ -58,112 +57,64 @@ int64_t compute_dynamic_offset(
|
||||
}
|
||||
}
|
||||
|
||||
void Abs::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), unsignedinteger)) {
|
||||
// No-op for unsigned types
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
unary(in, out, detail::Abs());
|
||||
}
|
||||
void AsStrided::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void Broadcast::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void BroadcastAxes::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void Copy::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void CustomTransforms::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
void Depends::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
void ExpandDims::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void NumberOfElements::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void Slice::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void Split::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
void Squeeze::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void StopGradient::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Arange::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
arange(inputs, out, start_, step_);
|
||||
}
|
||||
|
||||
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcCos());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arccos] Cannot compute inverse cosine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcCosh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arccosh] Cannot compute inverse hyperbolic cosine of elements in"
|
||||
" array with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcSin());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arcsin] Cannot compute inverse sine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcSinh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arcsinh] Cannot compute inverse hyperbolic sine of elements in"
|
||||
" array with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcTan());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arctan] Cannot compute inverse tangent of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcTanh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arctanh] Cannot compute inverse hyperbolic tangent of elements in"
|
||||
" array with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void AsType::eval(const std::vector<array>& inputs, array& out) {
|
||||
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype);
|
||||
}
|
||||
|
||||
void Ceil::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Ceil());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Concatenate::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
std::vector<int> sizes;
|
||||
sizes.push_back(0);
|
||||
for (auto& p : inputs) {
|
||||
@@ -187,17 +138,6 @@ void Concatenate::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Conjugate::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == complex64) {
|
||||
unary_fp(in, out, detail::Conjugate());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[conjugate] conjugate must be called on complex input.");
|
||||
}
|
||||
}
|
||||
|
||||
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -209,94 +149,6 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Cos());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[cos] Cannot compute cosine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Cosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Cosh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[cosh] Cannot compute hyperbolic cosine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Erf::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
unary_op<float>(in, out, detail::Erf());
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, detail::Erf());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, detail::Erf());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf] Error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ErfInv::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
unary_op<float>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, detail::ErfInv());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf_inv] Inverse error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Exp());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[exp] Cannot exponentiate elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Expm1::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Expm1());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[expm1] Cannot exponentiate elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
@@ -305,18 +157,7 @@ void Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
|
||||
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Floor());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Full::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
assert(in.dtype() == out.dtype());
|
||||
@@ -331,57 +172,14 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
|
||||
copy(in, out, ctype);
|
||||
}
|
||||
|
||||
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
|
||||
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
load(out, offset_, reader_, swap_endianness_);
|
||||
}
|
||||
|
||||
void Log::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_fp(in, out, detail::Log());
|
||||
break;
|
||||
case Base::two:
|
||||
unary_fp(in, out, detail::Log2());
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_fp(in, out, detail::Log10());
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[log] Cannot compute log of elements in array with"
|
||||
" non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Log1p::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Log1p());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[log1p] Cannot compute log of elements in array with"
|
||||
" non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::LogicalNot());
|
||||
}
|
||||
|
||||
void Negative::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::Negative());
|
||||
}
|
||||
|
||||
void Pad::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// Inputs must be base input array and scalar val array
|
||||
assert(inputs.size() == 2);
|
||||
auto& in = inputs[0];
|
||||
@@ -412,7 +210,7 @@ void Pad::eval(const std::vector<array>& inputs, array& out) {
|
||||
copy_inplace(in, out_slice, CopyType::GeneralGeneral);
|
||||
}
|
||||
|
||||
void RandomBits::eval(const std::vector<array>& inputs, array& out) {
|
||||
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
// keys has shape (N1, ..., NK, 2)
|
||||
// out has shape (N1, ..., NK, M1, M2, ...)
|
||||
@@ -460,71 +258,10 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
|
||||
}
|
||||
|
||||
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
|
||||
void Round::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Round());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Sigmoid());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[sigmoid] Cannot sigmoid of elements in array with"
|
||||
" non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Sign::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == bool_) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
unary(in, out, detail::Sign());
|
||||
}
|
||||
}
|
||||
|
||||
void Sin::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Sin());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[sin] Cannot compute sine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Sinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Sinh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[sinh] Cannot compute hyperbolic sine of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (out.size() == 0) {
|
||||
@@ -596,7 +333,7 @@ void DynamicSliceUpdate::eval_cpu(
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral);
|
||||
}
|
||||
|
||||
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
|
||||
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
@@ -632,46 +369,6 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral);
|
||||
}
|
||||
|
||||
void Square::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::Square());
|
||||
}
|
||||
|
||||
void Sqrt::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (recip_) {
|
||||
unary_fp(in, out, detail::Rsqrt());
|
||||
} else {
|
||||
unary_fp(in, out, detail::Sqrt());
|
||||
}
|
||||
}
|
||||
|
||||
void Tan::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Tan());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[tan] Cannot compute tangent of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Tanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Tanh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[tanh] Cannot compute hyperbolic tangent of elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
@@ -149,7 +149,9 @@ void qrf_impl(const array& a, array& q, array& r) {
|
||||
allocator::free(tau);
|
||||
}
|
||||
|
||||
void QRF::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
|
||||
void QRF::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
if (!(inputs[0].dtype() == float32)) {
|
||||
throw std::runtime_error("[QRF::eval] only supports float32.");
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
@@ -151,6 +151,78 @@ void _qmm_t(
|
||||
}
|
||||
}
|
||||
|
||||
template <int bits, int S>
|
||||
simd::Simd<uint32_t, S> extract_bits_simd(const uint32_t* w) {
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
simd::Simd<uint32_t, S> wi;
|
||||
if constexpr (bits == 4 && S == 8) {
|
||||
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
|
||||
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
|
||||
wi = simd::Simd<uint32_t, S>(*w);
|
||||
wi = wi >> shifts;
|
||||
wi = wi & bitmask;
|
||||
} else if constexpr (bits == 8 && S == 8) {
|
||||
constexpr std::array<uint32_t, 8> shifts_ = {{0, 8, 16, 24, 0, 8, 16, 24}};
|
||||
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
|
||||
auto l = simd::Simd<uint32_t, 4>(*w++);
|
||||
auto r = simd::Simd<uint32_t, 4>(*w);
|
||||
wi = simd::Simd<uint32_t, S>(l, r);
|
||||
wi = wi >> shifts;
|
||||
wi = wi & bitmask;
|
||||
} else {
|
||||
// Appease compiler.. but should never get here
|
||||
throw std::runtime_error("Unsupported combination for simd qmm.");
|
||||
}
|
||||
return wi;
|
||||
}
|
||||
|
||||
template <typename T, int bits, int group_size>
|
||||
void _qmm_t_simd(
|
||||
T* result,
|
||||
const T* x,
|
||||
const uint32_t* w,
|
||||
const T* scales,
|
||||
const T* biases,
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
constexpr int S = simd::max_size<T>;
|
||||
static_assert(
|
||||
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
|
||||
constexpr int packs_per_simd = S / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint32_t* w_local = w;
|
||||
const T* scales_local = scales;
|
||||
const T* biases_local = biases;
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
simd::Simd<float, S> acc(0);
|
||||
auto x_local = x;
|
||||
for (int k = 0; k < K; k += group_size) {
|
||||
T scale = *scales_local++;
|
||||
T bias = *biases_local++;
|
||||
|
||||
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
|
||||
auto wf = simd::Simd<float, S>(extract_bits_simd<bits, S>(w_local));
|
||||
w_local += packs_per_simd;
|
||||
wf = wf * scale;
|
||||
wf = wf + bias;
|
||||
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
|
||||
acc = acc + x_simd * wf;
|
||||
x_local += S;
|
||||
}
|
||||
}
|
||||
|
||||
*result = T(simd::sum(acc));
|
||||
result++;
|
||||
}
|
||||
x += K;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int bits, int group_size>
|
||||
void _qmm_dispatch_transpose(
|
||||
T* result,
|
||||
@@ -163,9 +235,14 @@ void _qmm_dispatch_transpose(
|
||||
int K,
|
||||
bool transposed_w) {
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
|
||||
// the simd size must be a multiple of the number of elements per word
|
||||
if constexpr (32 % bits == 0 && simd::max_size<T> % (32 / bits) == 0) {
|
||||
_qmm_t_simd<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
_qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
} else {
|
||||
return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
|
||||
_qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -249,13 +326,13 @@ void _qmm_dispatch(
|
||||
int group_size,
|
||||
bool transposed_w) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.shape(-2);
|
||||
int M = x.ndim() > 1 ? x.shape(-2) : 1;
|
||||
int N = out.shape(-1);
|
||||
|
||||
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
|
||||
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
|
||||
|
||||
int batch_size = x.size() / x.shape(-1) / x.shape(-2);
|
||||
int batch_size = x.size() / (K * M);
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
@@ -384,7 +461,7 @@ void _bs_qmm_dispatch(
|
||||
|
||||
} // namespace
|
||||
|
||||
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
|
||||
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 4);
|
||||
|
||||
auto& x_pre = inputs[0];
|
||||
@@ -411,7 +488,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
|
||||
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
|
||||
}
|
||||
|
||||
void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 6);
|
||||
|
||||
auto& x_pre = inputs[0];
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/binary_ops.h"
|
||||
#include "mlx/backend/common/ternary.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -61,7 +62,7 @@ void select_op(
|
||||
|
||||
} // namespace
|
||||
|
||||
void Select::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Select::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 3);
|
||||
const auto& condition = inputs[0];
|
||||
const auto& a = inputs[1];
|
||||
|
||||
56
mlx/backend/common/simd/accelerate_fp16_simd.h
Normal file
56
mlx/backend/common/simd/accelerate_fp16_simd.h
Normal file
@@ -0,0 +1,56 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/base_simd.h"
|
||||
|
||||
#if MLX_SIMD_LIBRARY_VERSION < 6
|
||||
#include "mlx/backend/common/simd/neon_fp16_simd.h"
|
||||
#endif
|
||||
|
||||
namespace mlx::core::simd {
|
||||
|
||||
#if MLX_SIMD_LIBRARY_VERSION >= 6
|
||||
constexpr int N = 8;
|
||||
template <int N>
|
||||
struct ScalarT<float16_t, N> {
|
||||
using v = _Float16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
static constexpr int max_size<float16_t> = N;
|
||||
|
||||
#define SIMD_FP16_DEFAULT_UNARY(op) \
|
||||
template <> \
|
||||
inline Simd<float16_t, N> op(Simd<float16_t, N> v) { \
|
||||
Simd<float, N> in = v; \
|
||||
return op(in); \
|
||||
}
|
||||
|
||||
SIMD_FP16_DEFAULT_UNARY(acos)
|
||||
SIMD_FP16_DEFAULT_UNARY(acosh)
|
||||
SIMD_FP16_DEFAULT_UNARY(asin)
|
||||
SIMD_FP16_DEFAULT_UNARY(asinh)
|
||||
SIMD_FP16_DEFAULT_UNARY(atan)
|
||||
SIMD_FP16_DEFAULT_UNARY(atanh)
|
||||
SIMD_FP16_DEFAULT_UNARY(cosh)
|
||||
SIMD_FP16_DEFAULT_UNARY(expm1)
|
||||
SIMD_FP16_DEFAULT_UNARY(log)
|
||||
SIMD_FP16_DEFAULT_UNARY(log2)
|
||||
SIMD_FP16_DEFAULT_UNARY(log10)
|
||||
SIMD_FP16_DEFAULT_UNARY(log1p)
|
||||
SIMD_FP16_DEFAULT_UNARY(sinh)
|
||||
SIMD_FP16_DEFAULT_UNARY(tan)
|
||||
SIMD_FP16_DEFAULT_UNARY(tanh)
|
||||
|
||||
#define SIMD_FP16_DEFAULT_BINARY(op) \
|
||||
template <> \
|
||||
inline Simd<float16_t, N> op(Simd<float16_t, N> x, Simd<float16_t, N> y) { \
|
||||
Simd<float, N> a = x; \
|
||||
Simd<float, N> b = y; \
|
||||
return op(a, b); \
|
||||
}
|
||||
SIMD_FP16_DEFAULT_BINARY(atan2)
|
||||
SIMD_FP16_DEFAULT_BINARY(remainder)
|
||||
SIMD_FP16_DEFAULT_BINARY(pow)
|
||||
|
||||
} // namespace mlx::core::simd
|
||||
291
mlx/backend/common/simd/accelerate_simd.h
Normal file
291
mlx/backend/common/simd/accelerate_simd.h
Normal file
@@ -0,0 +1,291 @@
|
||||
#pragma once
|
||||
|
||||
#include <simd/math.h>
|
||||
#include <simd/vector.h>
|
||||
|
||||
#include <stdint.h>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
|
||||
#include "mlx/backend/common/simd/base_simd.h"
|
||||
|
||||
// There seems to be a bug in sims/base.h
|
||||
// __XROS_2_0 is not defined, the expression evaluates
|
||||
// to true instead of false setting the SIMD library
|
||||
// higher than it should be even on macOS < 15
|
||||
#if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 || \
|
||||
__IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \
|
||||
__WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
|
||||
__WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
|
||||
__TV_OS_VERSION_MIN_REQUIRED >= 180000
|
||||
#define MLX_SIMD_LIBRARY_VERSION 6
|
||||
#else
|
||||
#define MLX_SIMD_LIBRARY_VERSION 5
|
||||
#endif
|
||||
|
||||
namespace mlx::core::simd {
|
||||
|
||||
// Apple simd namespace
|
||||
namespace asd = ::simd;
|
||||
|
||||
// This indirection is needed to remap certain types to ones that accelerate
|
||||
// SIMD can handle
|
||||
template <typename T, int N>
|
||||
struct ScalarT {
|
||||
using v = T;
|
||||
};
|
||||
template <int N>
|
||||
struct ScalarT<bool, N> {
|
||||
using v = char;
|
||||
};
|
||||
template <int N>
|
||||
struct ScalarT<int8_t, N> {
|
||||
using v = char;
|
||||
};
|
||||
template <int N>
|
||||
struct ScalarT<uint64_t, N> {
|
||||
using v = unsigned long;
|
||||
};
|
||||
template <int N>
|
||||
struct ScalarT<int64_t, N> {
|
||||
using v = long;
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
struct Simd {
|
||||
static constexpr int size = N;
|
||||
using scalar_t = typename ScalarT<T, N>::v;
|
||||
|
||||
Simd<T, N>() {}
|
||||
|
||||
template <typename U>
|
||||
Simd<T, N>(Simd<U, N> other) : value(asd::convert<scalar_t>(other.value)) {}
|
||||
|
||||
template <typename U>
|
||||
Simd<T, N>(U v) : value(v){};
|
||||
|
||||
Simd<T, N>(Simd<T, N / 2> x, Simd<T, N / 2> y) {
|
||||
value = asd::make<typename asd::Vector<scalar_t, N>::packed_t>(
|
||||
x.value, y.value);
|
||||
};
|
||||
|
||||
T operator[](int idx) const {
|
||||
return reinterpret_cast<const T*>(&value)[idx];
|
||||
}
|
||||
|
||||
T& operator[](int idx) {
|
||||
return reinterpret_cast<T*>(&value)[idx];
|
||||
}
|
||||
|
||||
typename asd::Vector<scalar_t, N>::packed_t value;
|
||||
};
|
||||
|
||||
// Values chosen based on benchmarks on M3 Max
|
||||
// TODO: consider choosing these more optimally
|
||||
template <>
|
||||
static constexpr int max_size<int8_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<int16_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<int> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<int64_t> = 4;
|
||||
template <>
|
||||
static constexpr int max_size<uint8_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<uint16_t> = 16;
|
||||
template <>
|
||||
static constexpr int max_size<uint32_t> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<uint64_t> = 4;
|
||||
template <>
|
||||
static constexpr int max_size<float> = 8;
|
||||
template <>
|
||||
static constexpr int max_size<double> = 4;
|
||||
|
||||
#define SIMD_DEFAULT_UNARY(name, op) \
|
||||
template <typename T, int N> \
|
||||
Simd<T, N> name(Simd<T, N> v) { \
|
||||
return op(v.value); \
|
||||
}
|
||||
|
||||
SIMD_DEFAULT_UNARY(abs, asd::abs)
|
||||
SIMD_DEFAULT_UNARY(floor, asd::floor)
|
||||
SIMD_DEFAULT_UNARY(acos, asd::acos)
|
||||
SIMD_DEFAULT_UNARY(acosh, asd::acosh)
|
||||
SIMD_DEFAULT_UNARY(asin, asd::asin)
|
||||
SIMD_DEFAULT_UNARY(asinh, asd::asinh)
|
||||
SIMD_DEFAULT_UNARY(atan, asd::atan)
|
||||
SIMD_DEFAULT_UNARY(atanh, asd::atanh)
|
||||
SIMD_DEFAULT_UNARY(ceil, asd::ceil)
|
||||
SIMD_DEFAULT_UNARY(cosh, asd::cosh)
|
||||
SIMD_DEFAULT_UNARY(expm1, asd::expm1)
|
||||
SIMD_DEFAULT_UNARY(log, asd::log)
|
||||
SIMD_DEFAULT_UNARY(log2, asd::log2)
|
||||
SIMD_DEFAULT_UNARY(log10, asd::log10)
|
||||
SIMD_DEFAULT_UNARY(log1p, asd::log1p)
|
||||
SIMD_DEFAULT_UNARY(rint, asd::rint)
|
||||
SIMD_DEFAULT_UNARY(sinh, asd::sinh)
|
||||
SIMD_DEFAULT_UNARY(sqrt, asd::sqrt)
|
||||
SIMD_DEFAULT_UNARY(rsqrt, asd::rsqrt)
|
||||
SIMD_DEFAULT_UNARY(recip, asd::recip)
|
||||
SIMD_DEFAULT_UNARY(tan, asd::tan)
|
||||
SIMD_DEFAULT_UNARY(tanh, asd::tanh)
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> operator-(Simd<T, N> v) {
|
||||
return -v.value;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<bool, N> isnan(Simd<T, N> v) {
|
||||
return asd::convert<char>(v.value != v.value);
|
||||
}
|
||||
|
||||
// No simd_boolN in accelerate, use int8_t instead
|
||||
template <typename T, int N>
|
||||
Simd<bool, N> operator!(Simd<T, N> v) {
|
||||
return asd::convert<char>(!v.value);
|
||||
}
|
||||
|
||||
#define SIMD_DEFAULT_BINARY(OP) \
|
||||
template <typename T, typename U, int N> \
|
||||
Simd<T, N> operator OP(Simd<T, N> x, U y) { \
|
||||
return asd::convert<typename Simd<T, N>::scalar_t>(x.value OP y); \
|
||||
} \
|
||||
template <typename T1, typename T2, int N> \
|
||||
Simd<T2, N> operator OP(T1 x, Simd<T2, N> y) { \
|
||||
return asd::convert<typename Simd<T2, N>::scalar_t>(x OP y.value); \
|
||||
} \
|
||||
template <typename T1, typename T2, int N> \
|
||||
Simd<T1, N> operator OP(Simd<T1, N> x, Simd<T2, N> y) { \
|
||||
return asd::convert<typename Simd<T1, N>::scalar_t>(x.value OP y.value); \
|
||||
}
|
||||
|
||||
SIMD_DEFAULT_BINARY(+)
|
||||
SIMD_DEFAULT_BINARY(-)
|
||||
SIMD_DEFAULT_BINARY(/)
|
||||
SIMD_DEFAULT_BINARY(*)
|
||||
SIMD_DEFAULT_BINARY(<<)
|
||||
SIMD_DEFAULT_BINARY(>>)
|
||||
SIMD_DEFAULT_BINARY(|)
|
||||
SIMD_DEFAULT_BINARY(^)
|
||||
SIMD_DEFAULT_BINARY(&)
|
||||
SIMD_DEFAULT_BINARY(&&)
|
||||
SIMD_DEFAULT_BINARY(||)
|
||||
|
||||
#define SIMD_DEFAULT_COMPARISONS(OP) \
|
||||
template <int N, typename T, typename U> \
|
||||
Simd<bool, N> operator OP(Simd<T, N> a, U b) { \
|
||||
return asd::convert<char>(a.value OP b); \
|
||||
} \
|
||||
template <int N, typename T, typename U> \
|
||||
Simd<bool, N> operator OP(T a, Simd<U, N> b) { \
|
||||
return asd::convert<char>(a OP b.value); \
|
||||
} \
|
||||
template <int N, typename T1, typename T2> \
|
||||
Simd<bool, N> operator OP(Simd<T1, N> a, Simd<T2, N> b) { \
|
||||
return asd::convert<char>(a.value OP b.value); \
|
||||
}
|
||||
|
||||
SIMD_DEFAULT_COMPARISONS(>)
|
||||
SIMD_DEFAULT_COMPARISONS(<)
|
||||
SIMD_DEFAULT_COMPARISONS(>=)
|
||||
SIMD_DEFAULT_COMPARISONS(<=)
|
||||
SIMD_DEFAULT_COMPARISONS(==)
|
||||
SIMD_DEFAULT_COMPARISONS(!=)
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
|
||||
return asd::atan2(a.value, b.value);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
|
||||
// TODO add isnan
|
||||
return asd::max(a.value, b.value);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {
|
||||
// TODO add isnan
|
||||
return asd::min(a.value, b.value);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
|
||||
Simd<T, N> r;
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
r = asd::remainder(a.value, b.value);
|
||||
} else {
|
||||
r = a - b * (a / b);
|
||||
}
|
||||
if constexpr (std::is_signed_v<T>) {
|
||||
auto mask = r != 0 && (r < 0 != b < 0);
|
||||
r = select(mask, r + b, r);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename MaskT, typename T1, typename T2, int N>
|
||||
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
|
||||
if constexpr (sizeof(T1) == 1) {
|
||||
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
|
||||
} else if constexpr (sizeof(T1) == 2) {
|
||||
return asd::bitselect(y.value, x.value, asd::convert<short>(mask.value));
|
||||
} else if constexpr (sizeof(T1) == 4) {
|
||||
return asd::bitselect(y.value, x.value, asd::convert<int>(mask.value));
|
||||
} else {
|
||||
return asd::bitselect(y.value, x.value, asd::convert<long>(mask.value));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
return asd::pow(base.value, exp.value);
|
||||
} else {
|
||||
Simd<T, N> res = 1;
|
||||
while (any(exp)) {
|
||||
res = select(exp & 1, res * base, res);
|
||||
base = select(exp, base * base, base);
|
||||
exp = exp >> 1;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> clamp(Simd<T, N> v, Simd<T, N> min, Simd<T, N> max) {
|
||||
return asd::clamp(v.value, min.value, max.value);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int N>
|
||||
Simd<T, N> fma(Simd<T, N> x, Simd<T, N> y, U z) {
|
||||
return asd::muladd(x.value, y.value, Simd<T, N>(z).value);
|
||||
}
|
||||
|
||||
// Reductions
|
||||
|
||||
template <typename T, int N>
|
||||
bool any(Simd<T, N> x) {
|
||||
return asd::any(x.value);
|
||||
}
|
||||
template <typename T, int N>
|
||||
T sum(Simd<T, N> x) {
|
||||
return asd::reduce_add(x.value);
|
||||
}
|
||||
template <typename T, int N>
|
||||
T max(Simd<T, N> x) {
|
||||
return asd::reduce_max(x.value);
|
||||
}
|
||||
template <typename T, int N>
|
||||
T min(Simd<T, N> x) {
|
||||
return asd::reduce_min(x.value);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::simd
|
||||
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
#include "mlx/backend/common/simd/accelerate_fp16_simd.h"
|
||||
#endif
|
||||
252
mlx/backend/common/simd/base_simd.h
Normal file
252
mlx/backend/common/simd/base_simd.h
Normal file
@@ -0,0 +1,252 @@
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
|
||||
namespace mlx::core::simd {
|
||||
template <typename T, int N>
|
||||
struct Simd;
|
||||
|
||||
template <typename T>
|
||||
static constexpr int max_size = 1;
|
||||
|
||||
template <typename T>
|
||||
struct Simd<T, 1> {
|
||||
static constexpr int size = 1;
|
||||
T value;
|
||||
Simd() {}
|
||||
template <typename U>
|
||||
Simd(Simd<U, 1> v) : value(v.value) {}
|
||||
template <typename U>
|
||||
Simd(U v) : value(v) {}
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> load(const T* x) {
|
||||
return *(Simd<T, N>*)x;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
void store(T* dst, Simd<T, N> x) {
|
||||
// Maintain invariant that bool is either 0 or 1 as
|
||||
// simd comparison ops set all bits in the result to 1
|
||||
if constexpr (std::is_same_v<T, bool> && N > 1) {
|
||||
x = x & 1;
|
||||
}
|
||||
*(Simd<T, N>*)dst = x;
|
||||
}
|
||||
|
||||
template <typename, typename = void>
|
||||
constexpr bool is_complex = false;
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_complex<T, std::void_t<decltype(std::declval<T>().real())>> =
|
||||
true;
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> rint(Simd<T, 1> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
return Simd<T, 1>{
|
||||
T{std::rint(in.value.real()), std::rint(in.value.imag())}};
|
||||
} else {
|
||||
return Simd<T, 1>{std::rint(in.value)};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> rsqrt(Simd<T, 1> in) {
|
||||
return T(1.0) / sqrt(in);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> recip(Simd<T, 1> in) {
|
||||
return T(1.0) / in;
|
||||
}
|
||||
|
||||
#define DEFAULT_UNARY(name, op) \
|
||||
template <typename T> \
|
||||
Simd<T, 1> name(Simd<T, 1> in) { \
|
||||
return op(in.value); \
|
||||
}
|
||||
|
||||
DEFAULT_UNARY(operator-, std::negate{})
|
||||
DEFAULT_UNARY(operator!, std::logical_not{})
|
||||
DEFAULT_UNARY(abs, std::abs)
|
||||
DEFAULT_UNARY(acos, std::acos)
|
||||
DEFAULT_UNARY(acosh, std::acosh)
|
||||
DEFAULT_UNARY(asin, std::asin)
|
||||
DEFAULT_UNARY(asinh, std::asinh)
|
||||
DEFAULT_UNARY(atan, std::atan)
|
||||
DEFAULT_UNARY(atanh, std::atanh)
|
||||
DEFAULT_UNARY(ceil, std::ceil)
|
||||
DEFAULT_UNARY(conj, std::conj)
|
||||
DEFAULT_UNARY(cosh, std::cosh)
|
||||
DEFAULT_UNARY(expm1, std::expm1)
|
||||
DEFAULT_UNARY(floor, std::floor)
|
||||
DEFAULT_UNARY(log, std::log)
|
||||
DEFAULT_UNARY(log2, std::log2)
|
||||
DEFAULT_UNARY(log10, std::log10)
|
||||
DEFAULT_UNARY(log1p, std::log1p)
|
||||
DEFAULT_UNARY(sinh, std::sinh)
|
||||
DEFAULT_UNARY(sqrt, std::sqrt)
|
||||
DEFAULT_UNARY(tan, std::tan)
|
||||
DEFAULT_UNARY(tanh, std::tanh)
|
||||
|
||||
template <typename T>
|
||||
auto real(Simd<T, 1> in) -> Simd<decltype(std::real(in.value)), 1> {
|
||||
return std::real(in.value);
|
||||
}
|
||||
template <typename T>
|
||||
auto imag(Simd<T, 1> in) -> Simd<decltype(std::imag(in.value)), 1> {
|
||||
return std::imag(in.value);
|
||||
}
|
||||
template <typename T>
|
||||
Simd<bool, 1> isnan(Simd<T, 1> in) {
|
||||
return std::isnan(in.value);
|
||||
}
|
||||
|
||||
#define DEFAULT_BINARY(OP) \
|
||||
template <typename T1, typename T2> \
|
||||
auto operator OP(Simd<T1, 1> a, Simd<T2, 1> b) \
|
||||
->Simd<decltype(a.value OP b.value), 1> { \
|
||||
return a.value OP b.value; \
|
||||
} \
|
||||
template <typename T1, typename T2> \
|
||||
auto operator OP(T1 a, Simd<T2, 1> b)->Simd<decltype(a OP b.value), 1> { \
|
||||
return a OP b.value; \
|
||||
} \
|
||||
template <typename T1, typename T2> \
|
||||
auto operator OP(Simd<T1, 1> a, T2 b)->Simd<decltype(a.value OP b), 1> { \
|
||||
return a.value OP b; \
|
||||
}
|
||||
|
||||
DEFAULT_BINARY(+)
|
||||
DEFAULT_BINARY(-)
|
||||
DEFAULT_BINARY(*)
|
||||
DEFAULT_BINARY(/)
|
||||
DEFAULT_BINARY(<<)
|
||||
DEFAULT_BINARY(>>)
|
||||
DEFAULT_BINARY(|)
|
||||
DEFAULT_BINARY(^)
|
||||
DEFAULT_BINARY(&)
|
||||
DEFAULT_BINARY(&&)
|
||||
DEFAULT_BINARY(||)
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
|
||||
T a = a_.value;
|
||||
T b = b_.value;
|
||||
T r;
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
r = a % b;
|
||||
} else {
|
||||
r = std::remainder(a, b);
|
||||
}
|
||||
if constexpr (std::is_signed_v<T>) {
|
||||
if (r != 0 && (r < 0 != b < 0)) {
|
||||
r += b;
|
||||
}
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> maximum(Simd<T, 1> a_, Simd<T, 1> b_) {
|
||||
T a = a_.value;
|
||||
T b = b_.value;
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
if (std::isnan(a)) {
|
||||
return a;
|
||||
}
|
||||
}
|
||||
return (a > b) ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> minimum(Simd<T, 1> a_, Simd<T, 1> b_) {
|
||||
T a = a_.value;
|
||||
T b = b_.value;
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
if (std::isnan(a)) {
|
||||
return a;
|
||||
}
|
||||
}
|
||||
return (a < b) ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> pow(Simd<T, 1> a, Simd<T, 1> b) {
|
||||
T base = a.value;
|
||||
T exp = b.value;
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
return std::pow(base, exp);
|
||||
} else {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> atan2(Simd<T, 1> a, Simd<T, 1> b) {
|
||||
return std::atan2(a.value, b.value);
|
||||
}
|
||||
|
||||
#define DEFAULT_COMPARISONS(OP) \
|
||||
template <typename T1, typename T2> \
|
||||
Simd<bool, 1> operator OP(Simd<T1, 1> a, Simd<T2, 1> b) { \
|
||||
return a.value OP b.value; \
|
||||
} \
|
||||
template <typename T1, typename T2> \
|
||||
Simd<bool, 1> operator OP(T1 a, Simd<T2, 1> b) { \
|
||||
return a OP b.value; \
|
||||
} \
|
||||
template <typename T1, typename T2> \
|
||||
Simd<bool, 1> operator OP(Simd<T1, 1> a, T2 b) { \
|
||||
return a.value OP b; \
|
||||
}
|
||||
|
||||
DEFAULT_COMPARISONS(>)
|
||||
DEFAULT_COMPARISONS(<)
|
||||
DEFAULT_COMPARISONS(>=)
|
||||
DEFAULT_COMPARISONS(<=)
|
||||
DEFAULT_COMPARISONS(==)
|
||||
DEFAULT_COMPARISONS(!=)
|
||||
|
||||
template <typename MaskT, typename T>
|
||||
Simd<T, 1> select(Simd<MaskT, 1> mask, Simd<T, 1> x, Simd<T, 1> y) {
|
||||
return mask.value ? x.value : y.value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> clamp(Simd<T, 1> v, Simd<T, 1> min, Simd<T, 1> max) {
|
||||
return std::clamp(v.value, min.value, max.value);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
Simd<T, 1> fma(Simd<T, 1> x, Simd<T, 1> y, U z) {
|
||||
return std::fma(x.value, y.value, Simd<T, 1>(z).value);
|
||||
}
|
||||
|
||||
// Reductions
|
||||
#define DEFAULT_REDUCTION(name, type) \
|
||||
template <typename T> \
|
||||
type name(Simd<T, 1> x) { \
|
||||
return x.value; \
|
||||
}
|
||||
|
||||
DEFAULT_REDUCTION(max, T)
|
||||
DEFAULT_REDUCTION(min, T)
|
||||
DEFAULT_REDUCTION(sum, T)
|
||||
DEFAULT_REDUCTION(any, bool)
|
||||
DEFAULT_REDUCTION(all, bool)
|
||||
|
||||
} // namespace mlx::core::simd
|
||||
193
mlx/backend/common/simd/math.h
Normal file
193
mlx/backend/common/simd/math.h
Normal file
@@ -0,0 +1,193 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/type.h"
|
||||
|
||||
namespace mlx::core::simd {
|
||||
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
|
||||
/**
|
||||
* Compute exp(x) in an optimizer friendly way as follows:
|
||||
*
|
||||
* First change the problem to computing 2**y where y = x / ln(2).
|
||||
*
|
||||
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
|
||||
* `ipart` and y2 is fractional part. For the integer part we perform bit
|
||||
* shifting and for the fractional part we use a polynomial approximation.
|
||||
*
|
||||
* The algorithm and constants of the polynomial taken from
|
||||
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
|
||||
* from Cephes math library.
|
||||
*
|
||||
* Note: The implementation below is a general fast exp. There could be faster
|
||||
* implementations for numbers strictly < 0.
|
||||
*/
|
||||
template <typename T, int N>
|
||||
Simd<T, N> exp(Simd<T, N> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
return Simd<T, 1>{std::exp(in.value)};
|
||||
} else {
|
||||
Simd<float, N> x_init = in;
|
||||
auto x = x_init * 1.442695f; // multiply with log_2(e)
|
||||
Simd<float, N> ipart, fpart;
|
||||
ipart = floor(x + 0.5);
|
||||
fpart = x - ipart;
|
||||
|
||||
x = 1.535336188319500e-4f;
|
||||
x = fma(x, fpart, 1.339887440266574e-3f);
|
||||
x = fma(x, fpart, 9.618437357674640e-3f);
|
||||
x = fma(x, fpart, 5.550332471162809e-2f);
|
||||
x = fma(x, fpart, 2.402264791363012e-1f);
|
||||
x = fma(x, fpart, 6.931472028550421e-1f);
|
||||
x = fma(x, fpart, 1.000000000000000f);
|
||||
|
||||
// generate 2**ipart in the floating point representation using integer
|
||||
// bitshifting
|
||||
Simd<int, N> epart = (Simd<int, N>(ipart) + 127) << 23;
|
||||
|
||||
// Deal with NaN and Inf
|
||||
auto result = select(isnan(x_init), x_init, (*(Simd<float, N>*)&epart) * x);
|
||||
result = select(x_init > 88.0f, Simd<float, N>(inf), result);
|
||||
result = select(x_init < -88.0f, Simd<float, N>(0), result);
|
||||
return Simd<T, N>(result);
|
||||
}
|
||||
}
|
||||
|
||||
/* Implementation from:
|
||||
* https://github.com/JishinMaster/simd_utils/blob/3c1433a86fb38edcc9b02039f3c9a65b16640976/neon_mathfun.h#L357
|
||||
* which originally came from the Cephes math library.
|
||||
*/
|
||||
template <bool Sine, typename T, int N>
|
||||
Simd<T, N> sincos(Simd<T, N> in) {
|
||||
auto sign_mask_sin = in < 0;
|
||||
in = abs(in);
|
||||
Simd<float, N> x = in;
|
||||
|
||||
// scale by 4/Pi
|
||||
auto y = x * 1.27323954473516f;
|
||||
|
||||
// store the integer part of y in mm0
|
||||
Simd<uint32_t, N> emm2 = y;
|
||||
|
||||
// j=(j+1) & (~1) (see the cephes sources)
|
||||
emm2 = emm2 + 1;
|
||||
emm2 = emm2 & ~1;
|
||||
|
||||
y = emm2;
|
||||
|
||||
// Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4
|
||||
// and another one for Pi/4<x<=Pi/2. Both branches will be computed.
|
||||
auto poly_mask = (emm2 & 2) != 0;
|
||||
|
||||
// The magic pass: "Extended precision modular arithmetic"
|
||||
// x = ((x - y * DP1) - y * DP2) - y * DP3
|
||||
x = fma(y, Simd<float, N>(-0.78515625f), x);
|
||||
x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);
|
||||
x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);
|
||||
|
||||
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);
|
||||
auto sign_mask_cos = ((emm2 - 2) & 4) != 0;
|
||||
|
||||
// Evaluate the first polynom (0 <= x <= Pi/4) in y1,
|
||||
// and the second polynom (Pi/4 <= x <= 0) in y2
|
||||
auto z = x * x;
|
||||
|
||||
auto y1 =
|
||||
fma(z, Simd<float, N>(2.443315711809948e-5f), -1.388731625493765e-3f);
|
||||
auto y2 = fma(z, Simd<float, N>(-1.9515295891e-4f), 8.3321608736e-3f);
|
||||
y1 = fma(y1, z, 4.166664568298827e-2f);
|
||||
y2 = fma(y2, z, -1.6666654611e-1f);
|
||||
y1 = y1 * z;
|
||||
y2 = y2 * z;
|
||||
y1 = y1 * z;
|
||||
y2 = fma(x, y2, x);
|
||||
y1 = fma(z, Simd<float, N>(-0.5f), y1);
|
||||
y1 = y1 + 1.0f;
|
||||
|
||||
if constexpr (Sine) {
|
||||
auto ys = select(poly_mask, y1, y2);
|
||||
return select(sign_mask_sin, -ys, ys);
|
||||
} else {
|
||||
auto yc = select(poly_mask, y2, y1);
|
||||
return select(sign_mask_cos, yc, -yc);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> sin(Simd<T, N> x) {
|
||||
if constexpr (is_complex<T>) {
|
||||
return std::sin(x.value);
|
||||
} else {
|
||||
return sincos<true>(x);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> cos(Simd<T, N> x) {
|
||||
if constexpr (is_complex<T>) {
|
||||
return std::cos(x.value);
|
||||
} else {
|
||||
return sincos<false>(x);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> erf(Simd<T, N> x) {
|
||||
// https://github.com/pytorch/pytorch/blob/abf28982a8cb43342e7669d859de9543fd804cc9/aten/src/ATen/cpu/vec/vec256/vec256_float.h#L175
|
||||
Simd<float, N> v = x;
|
||||
auto t = recip(fma(Simd<float, N>(0.3275911f), abs(v), 1.0f));
|
||||
auto r = fma(Simd<float, N>(1.061405429f), t, -1.453152027f);
|
||||
r = fma(r, t, 1.421413741f);
|
||||
r = fma(r, t, -0.284496736f);
|
||||
r = fma(r, t, 0.254829592f);
|
||||
auto e = -exp(-v * v);
|
||||
auto result = Simd<T, N>(fma(e * t, r, 1.0f));
|
||||
return select(x > 0, result, -result);
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> erfinv(Simd<T, N> a_) {
|
||||
Simd<float, N> a = a_;
|
||||
auto t = fma(a, 0.0f - a, 1.0f);
|
||||
t = log(t);
|
||||
auto lhs = [](auto t) {
|
||||
Simd<float, N> p;
|
||||
p = 3.03697567e-10f; // 0x1.4deb44p-32
|
||||
p = fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
|
||||
p = fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
|
||||
p = fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
|
||||
p = fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
|
||||
p = fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
|
||||
p = fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
|
||||
p = fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
|
||||
return fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
|
||||
};
|
||||
auto rhs = [](auto t) {
|
||||
Simd<float, N> p;
|
||||
p = 5.43877832e-9f; // 0x1.75c000p-28
|
||||
p = fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
|
||||
p = fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
|
||||
p = fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
|
||||
p = fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
|
||||
p = fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
|
||||
p = fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
|
||||
p = fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
|
||||
p = fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
|
||||
return fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
|
||||
};
|
||||
auto thresh = 6.125f;
|
||||
// Compute both branches and select if N > 1
|
||||
if constexpr (N == 1) {
|
||||
if ((abs(t) > thresh).value) { // maximum ulp error = 2.35793
|
||||
return a * lhs(t);
|
||||
} else { // maximum ulp error = 2.35002
|
||||
return a * rhs(t);
|
||||
}
|
||||
} else {
|
||||
return a * select(t > thresh, lhs(t), rhs(t));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::simd
|
||||
204
mlx/backend/common/simd/neon_fp16_simd.h
Normal file
204
mlx/backend/common/simd/neon_fp16_simd.h
Normal file
@@ -0,0 +1,204 @@
|
||||
#pragma once
|
||||
|
||||
#include <arm_neon.h>
|
||||
|
||||
#include "mlx/backend/common/simd/base_simd.h"
|
||||
|
||||
namespace mlx::core::simd {
|
||||
|
||||
constexpr int N = 8;
|
||||
|
||||
template <>
|
||||
struct Simd<float16_t, N> {
|
||||
static constexpr int size = N;
|
||||
using scalar_t = float16_t;
|
||||
|
||||
Simd<float16_t, N>() {}
|
||||
|
||||
template <typename U>
|
||||
Simd<float16_t, N>(U v) : value(vdupq_n_f16(v)){};
|
||||
|
||||
Simd<float16_t, N>(float16x8_t v) : value(v){};
|
||||
|
||||
Simd<float16_t, N>(Simd<float, N> other) {
|
||||
auto f32x4_a = *(float32x4_t*)(&other);
|
||||
auto f32x4_b = *((float32x4_t*)(&other) + 1);
|
||||
value = vcvt_high_f16_f32(vcvt_f16_f32(f32x4_a), f32x4_b);
|
||||
};
|
||||
|
||||
Simd<float16_t, N>(Simd<uint16_t, N> other) {
|
||||
value = vcvtq_f16_u16(*(uint16x8_t*)(&other.value));
|
||||
};
|
||||
|
||||
operator Simd<int16_t, N>() {
|
||||
auto v = vcvtq_s16_f16(value);
|
||||
return load<int16_t, N>((int16_t*)&v);
|
||||
};
|
||||
|
||||
operator Simd<float, N>() {
|
||||
float32x4x2_t v;
|
||||
v.val[0] = vcvt_f32_f16(*(float16x4_t*)(&value));
|
||||
v.val[1] = vcvt_high_f32_f16(value);
|
||||
return load<float, N>((float*)&v);
|
||||
}
|
||||
float16_t operator[](int idx) const {
|
||||
return reinterpret_cast<const float16_t*>(&value)[idx];
|
||||
}
|
||||
|
||||
float16_t& operator[](int idx) {
|
||||
return reinterpret_cast<float16_t*>(&value)[idx];
|
||||
}
|
||||
|
||||
float16x8_t value;
|
||||
};
|
||||
|
||||
#define DEFINE_NEON_UNARY_OP(name, op) \
|
||||
inline Simd<float16_t, N> name(Simd<float16_t, N> a) { \
|
||||
return Simd<float16_t, N>{op(a.value)}; \
|
||||
}
|
||||
|
||||
DEFINE_NEON_UNARY_OP(abs, vabsq_f16)
|
||||
DEFINE_NEON_UNARY_OP(ceil, vrndpq_f16)
|
||||
DEFINE_NEON_UNARY_OP(floor, vrndmq_f16)
|
||||
DEFINE_NEON_UNARY_OP(sqrt, vsqrtq_f16)
|
||||
DEFINE_NEON_UNARY_OP(rsqrt, vrsqrteq_f16)
|
||||
DEFINE_NEON_UNARY_OP(recip, vrecpeq_f16)
|
||||
DEFINE_NEON_UNARY_OP(rint, vrndnq_f16)
|
||||
|
||||
#define DEFINE_NEON_BINARY_OP(name, op) \
|
||||
inline Simd<float16_t, N> name(Simd<float16_t, N> a, Simd<float16_t, N> b) { \
|
||||
return op(a.value, b.value); \
|
||||
} \
|
||||
template <typename T> \
|
||||
Simd<float16_t, N> name(Simd<float16_t, N> a, T b) { \
|
||||
return op(a.value, Simd<float16_t, N>(b).value); \
|
||||
} \
|
||||
template <typename T> \
|
||||
Simd<float16_t, N> name(T a, Simd<float16_t, N> b) { \
|
||||
return op(Simd<float16_t, N>(a).value, b.value); \
|
||||
}
|
||||
|
||||
inline Simd<float16_t, N> operator!(Simd<float16_t, N> v) {
|
||||
auto out = vceqzq_f16(v.value);
|
||||
return Simd<uint16_t, N>(*(uint16_t*)&out);
|
||||
}
|
||||
|
||||
inline Simd<float16_t, N> operator-(Simd<float16_t, N> v) {
|
||||
return vnegq_f16(v.value);
|
||||
}
|
||||
|
||||
DEFINE_NEON_BINARY_OP(maximum, vmaxq_f16)
|
||||
DEFINE_NEON_BINARY_OP(minimum, vminq_f16)
|
||||
DEFINE_NEON_BINARY_OP(operator+, vaddq_f16)
|
||||
DEFINE_NEON_BINARY_OP(operator-, vsubq_f16)
|
||||
DEFINE_NEON_BINARY_OP(operator*, vmulq_f16)
|
||||
DEFINE_NEON_BINARY_OP(operator/, vdivq_f16)
|
||||
|
||||
#define DEFINE_NEON_COMPARISON(Op, op) \
|
||||
template <typename T> \
|
||||
Simd<bool, N> operator Op(Simd<float16_t, N> a, T b) { \
|
||||
auto out = op(a.value, Simd<float16_t, N>(b).value); \
|
||||
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
|
||||
} \
|
||||
template <typename T> \
|
||||
Simd<bool, N> operator Op(T a, Simd<float16_t, N> b) { \
|
||||
auto out = op(Simd<float16_t, N>(a).value, b.value); \
|
||||
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
|
||||
} \
|
||||
inline Simd<bool, N> operator Op( \
|
||||
Simd<float16_t, N> a, Simd<float16_t, N> b) { \
|
||||
auto out = op(a.value, b.value); \
|
||||
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
|
||||
}
|
||||
|
||||
DEFINE_NEON_COMPARISON(==, vceqq_f16)
|
||||
DEFINE_NEON_COMPARISON(>=, vcgeq_f16)
|
||||
DEFINE_NEON_COMPARISON(<=, vcleq_f16)
|
||||
DEFINE_NEON_COMPARISON(>, vcgtq_f16)
|
||||
DEFINE_NEON_COMPARISON(<, vcltq_f16)
|
||||
|
||||
template <typename T>
|
||||
Simd<bool, N> operator!=(Simd<float16_t, N> a, T b) {
|
||||
return !(a == b);
|
||||
}
|
||||
template <typename T>
|
||||
Simd<bool, N> operator!=(T a, Simd<float16_t, N> b) {
|
||||
return !(a == b);
|
||||
}
|
||||
inline Simd<bool, N> operator!=(Simd<float16_t, N> a, Simd<float16_t, N> b) {
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
inline Simd<float16_t, N> operator||(
|
||||
Simd<float16_t, N> a,
|
||||
Simd<float16_t, N> b) {
|
||||
return Simd<uint16_t, N>((a != 0) || (b != 0));
|
||||
}
|
||||
template <typename T>
|
||||
Simd<float16_t, N> operator||(Simd<float16_t, N> a, T b) {
|
||||
return Simd<uint16_t, N>((a != 0) || (b != 0));
|
||||
}
|
||||
template <typename T>
|
||||
Simd<float16_t, N> operator||(T a, Simd<float16_t, N> b) {
|
||||
return Simd<uint16_t, N>((a != 0) || (b != 0));
|
||||
}
|
||||
inline Simd<float16_t, N> operator&&(
|
||||
Simd<float16_t, N> a,
|
||||
Simd<float16_t, N> b) {
|
||||
return Simd<uint16_t, N>((a != 0) && (b != 0));
|
||||
}
|
||||
template <typename T>
|
||||
Simd<float16_t, N> operator&&(Simd<float16_t, N> a, T b) {
|
||||
return Simd<uint16_t, N>((a != 0) && (b != 0));
|
||||
}
|
||||
template <typename T>
|
||||
Simd<float16_t, N> operator&&(T a, Simd<float16_t, N> b) {
|
||||
return Simd<uint16_t, N>((a != 0) && (b != 0));
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Simd<bool, N> isnan(Simd<float16_t, N> v) {
|
||||
return v != v;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Simd<float16_t, N>
|
||||
clamp(Simd<float16_t, N> v, Simd<float16_t, N> min, Simd<float16_t, N> max) {
|
||||
return minimum(maximum(v, min), max);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<float16_t, N> fma(Simd<float16_t, N> x, Simd<float16_t, N> y, T z) {
|
||||
return vfmaq_f16(x.value, y.value, Simd<float16_t, N>(z).value);
|
||||
}
|
||||
|
||||
template <typename MaskT>
|
||||
Simd<float16_t, N>
|
||||
select(Simd<MaskT, N> mask, Simd<float16_t, N> x, Simd<float16_t, N> y) {
|
||||
return vbslq_f16(Simd<uint16_t, N>(mask).value, x.value, y.value);
|
||||
}
|
||||
|
||||
// Reductions
|
||||
inline float16_t max(Simd<float16_t, N> x) {
|
||||
float16x4_t y;
|
||||
y = vpmax_f16(vget_low_f16(x.value), vget_high_f16(x.value));
|
||||
y = vpmax_f16(y, y);
|
||||
y = vpmax_f16(y, y);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
inline float16_t min(Simd<float16_t, N> x) {
|
||||
float16x4_t y;
|
||||
y = vpmin_f16(vget_low_f16(x.value), vget_high_f16(x.value));
|
||||
y = vpmin_f16(y, y);
|
||||
y = vpmin_f16(y, y);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
inline float16_t sum(Simd<float16_t, N> x) {
|
||||
float16x4_t y;
|
||||
y = vpadd_f16(vget_low_f16(x.value), vget_high_f16(x.value));
|
||||
y = vpadd_f16(y, y);
|
||||
y = vpadd_f16(y, y);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::simd
|
||||
4
mlx/backend/common/simd/simd.h
Normal file
4
mlx/backend/common/simd/simd.h
Normal file
@@ -0,0 +1,4 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/math.h"
|
||||
#include "mlx/backend/common/simd/type.h"
|
||||
7
mlx/backend/common/simd/type.h
Normal file
7
mlx/backend/common/simd/type.h
Normal file
@@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/simd/base_simd.h"
|
||||
|
||||
#ifdef MLX_USE_ACCELERATE
|
||||
#include "mlx/backend/common/simd/accelerate_simd.h"
|
||||
#endif
|
||||
@@ -4,61 +4,107 @@
|
||||
#include <cmath>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace mlx::core::simd;
|
||||
|
||||
template <typename T, typename AccT>
|
||||
void softmax(const array& in, array& out) {
|
||||
constexpr bool same_t = std::is_same_v<T, AccT>;
|
||||
constexpr int N = std::min(max_size<AccT>, max_size<T>);
|
||||
|
||||
const T* in_ptr = in.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
int N = in.shape().back();
|
||||
int M = in.data_size() / N;
|
||||
int M = in.shape().back();
|
||||
int L = in.data_size() / M;
|
||||
const T* current_in_ptr;
|
||||
T* current_out_ptr;
|
||||
|
||||
for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
|
||||
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
|
||||
// Find the maximum
|
||||
current_in_ptr = in_ptr;
|
||||
AccT maximum = *current_in_ptr;
|
||||
for (int j = 0; j < N; j++, current_in_ptr++) {
|
||||
maximum = (maximum < *current_in_ptr) ? static_cast<AccT>(*current_in_ptr)
|
||||
: maximum;
|
||||
Simd<AccT, N> vmaximum(-std::numeric_limits<float>::infinity());
|
||||
size_t s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
|
||||
vmaximum = maximum(vals, vmaximum);
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
|
||||
AccT maximum = max(vmaximum);
|
||||
while (s-- > 0) {
|
||||
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
|
||||
current_in_ptr++;
|
||||
}
|
||||
|
||||
// Compute the normalizer and the exponentials
|
||||
AccT normalizer = 0;
|
||||
Simd<AccT, N> vnormalizer(0.0);
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
|
||||
AccT expv = std::exp(*current_in_ptr - maximum);
|
||||
normalizer += expv;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr = expv;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
|
||||
vexp = exp(vexp - maximum);
|
||||
if constexpr (same_t) {
|
||||
store(current_out_ptr, vexp);
|
||||
}
|
||||
vnormalizer = vnormalizer + vexp;
|
||||
current_in_ptr += N;
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
AccT normalizer = sum(vnormalizer);
|
||||
while (s-- > 0) {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
if constexpr (same_t) {
|
||||
*current_out_ptr = _exp;
|
||||
}
|
||||
normalizer += _exp;
|
||||
current_in_ptr++;
|
||||
current_out_ptr++;
|
||||
}
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Normalize
|
||||
current_in_ptr = in_ptr;
|
||||
current_out_ptr = out_ptr;
|
||||
for (int j = 0; j < N; j++, current_out_ptr++) {
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
if constexpr (same_t) {
|
||||
store(
|
||||
current_out_ptr,
|
||||
Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
|
||||
} else {
|
||||
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
|
||||
vexp = exp(vexp - maximum) * normalizer;
|
||||
store(current_out_ptr, Simd<T, N>(vexp));
|
||||
current_in_ptr += N;
|
||||
}
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
if constexpr (same_t) {
|
||||
*current_out_ptr *= normalizer;
|
||||
} else {
|
||||
auto v = std::exp(*current_in_ptr - maximum);
|
||||
*current_out_ptr = static_cast<T>(v * normalizer);
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
*current_out_ptr = static_cast<T>(_exp * normalizer);
|
||||
current_in_ptr++;
|
||||
}
|
||||
current_out_ptr++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Softmax::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
@@ -97,7 +143,7 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
|
||||
case int16:
|
||||
case int32:
|
||||
case int64:
|
||||
throw std::invalid_argument(
|
||||
throw std::runtime_error(
|
||||
"Softmax is defined only for floating point types");
|
||||
break;
|
||||
case float32:
|
||||
|
||||
@@ -287,7 +287,7 @@ void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
|
||||
} // namespace
|
||||
|
||||
void ArgSort::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
@@ -321,7 +321,7 @@ void ArgSort::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Sort::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
@@ -355,7 +355,7 @@ void Sort::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void ArgPartition::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
@@ -389,7 +389,7 @@ void ArgPartition::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Partition::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
|
||||
@@ -137,7 +137,9 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||
}
|
||||
}
|
||||
|
||||
void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
|
||||
void SVD::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
if (!(inputs[0].dtype() == float32)) {
|
||||
throw std::runtime_error("[SVD::eval] only supports float32.");
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
#pragma once
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
285
mlx/backend/common/unary.cpp
Normal file
285
mlx/backend/common/unary.cpp
Normal file
@@ -0,0 +1,285 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/common/unary_ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), unsignedinteger) || in.dtype() == bool_) {
|
||||
// No-op for unsigned types
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
auto op = detail::Abs{};
|
||||
switch (out.dtype()) {
|
||||
case int8:
|
||||
unary_op<int8_t>(in, out, op);
|
||||
break;
|
||||
case int16:
|
||||
unary_op<int16_t>(in, out, op);
|
||||
break;
|
||||
case int32:
|
||||
unary_op<int32_t>(in, out, op);
|
||||
break;
|
||||
case int64:
|
||||
unary_op<int64_t>(in, out, op);
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, op);
|
||||
break;
|
||||
case float32:
|
||||
unary_op<float>(in, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
unary_op<complex64_t>(in, out, op);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[Abs] Called on unsigned type");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcCos());
|
||||
}
|
||||
|
||||
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcCosh());
|
||||
}
|
||||
|
||||
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcSin());
|
||||
}
|
||||
|
||||
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcSinh());
|
||||
}
|
||||
|
||||
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcTan());
|
||||
}
|
||||
|
||||
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::ArcTanh());
|
||||
}
|
||||
|
||||
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Ceil());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
unary_op<complex64_t>(inputs[0], out, detail::Conjugate());
|
||||
}
|
||||
|
||||
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Cos());
|
||||
}
|
||||
|
||||
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Cosh());
|
||||
}
|
||||
|
||||
void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
unary_op<float>(in, out, detail::Erf());
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, detail::Erf());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, detail::Erf());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf] Error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
unary_op<float>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case float16:
|
||||
unary_op<float16_t>(in, out, detail::ErfInv());
|
||||
break;
|
||||
case bfloat16:
|
||||
unary_op<bfloat16_t>(in, out, detail::ErfInv());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[erf_inv] Inverse error function only defined for arrays"
|
||||
" with real floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Exp());
|
||||
}
|
||||
|
||||
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Expm1());
|
||||
}
|
||||
|
||||
void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Floor());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
|
||||
}
|
||||
|
||||
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_fp(in, out, detail::Log());
|
||||
break;
|
||||
case Base::two:
|
||||
unary_fp(in, out, detail::Log2());
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_fp(in, out, detail::Log10());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Log1p());
|
||||
}
|
||||
|
||||
void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::LogicalNot());
|
||||
}
|
||||
|
||||
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::Negative());
|
||||
}
|
||||
|
||||
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
|
||||
}
|
||||
|
||||
void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Round());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Sigmoid());
|
||||
}
|
||||
|
||||
void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == bool_) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
unary(in, out, detail::Sign());
|
||||
}
|
||||
}
|
||||
|
||||
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Sin());
|
||||
}
|
||||
|
||||
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Sinh());
|
||||
}
|
||||
|
||||
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
unary(in, out, detail::Square());
|
||||
}
|
||||
|
||||
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (recip_) {
|
||||
unary_fp(in, out, detail::Rsqrt());
|
||||
} else {
|
||||
unary_fp(in, out, detail::Sqrt());
|
||||
}
|
||||
}
|
||||
|
||||
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Tan());
|
||||
}
|
||||
|
||||
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
unary_fp(in, out, detail::Tanh());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@@ -38,8 +39,19 @@ void unary_op(const array& a, array& out, Op op) {
|
||||
if (a.flags().contiguous) {
|
||||
set_unary_output_data(a, out);
|
||||
U* dst = out.data<U>();
|
||||
for (size_t i = 0; i < a.data_size(); ++i) {
|
||||
dst[i] = op(a_ptr[i]);
|
||||
constexpr int N = simd::max_size<T>;
|
||||
size_t size = a.data_size();
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::load<T, N>(a_ptr)));
|
||||
size -= N;
|
||||
a_ptr += N;
|
||||
dst += N;
|
||||
}
|
||||
while (size > 0) {
|
||||
*dst = op(*a_ptr);
|
||||
size--;
|
||||
dst++;
|
||||
a_ptr++;
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
108
mlx/backend/common/unary_ops.h
Normal file
108
mlx/backend/common/unary_ops.h
Normal file
@@ -0,0 +1,108 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
|
||||
namespace mlx::core::detail {
|
||||
|
||||
using namespace mlx::core::simd;
|
||||
|
||||
#define SINGLE() \
|
||||
template <typename T> \
|
||||
T operator()(T x) { \
|
||||
return (*this)(Simd<T, 1>(x)).value; \
|
||||
}
|
||||
|
||||
#define DEFAULT_OP(Op, op) \
|
||||
struct Op { \
|
||||
template <int N, typename T> \
|
||||
Simd<T, N> operator()(Simd<T, N> x) { \
|
||||
return simd::op(x); \
|
||||
} \
|
||||
SINGLE() \
|
||||
};
|
||||
|
||||
DEFAULT_OP(Abs, abs)
|
||||
DEFAULT_OP(ArcCos, acos)
|
||||
DEFAULT_OP(ArcCosh, acosh)
|
||||
DEFAULT_OP(ArcSin, asin)
|
||||
DEFAULT_OP(ArcSinh, asinh)
|
||||
DEFAULT_OP(ArcTan, atan)
|
||||
DEFAULT_OP(ArcTanh, atanh)
|
||||
DEFAULT_OP(Ceil, ceil)
|
||||
DEFAULT_OP(Conjugate, conj)
|
||||
DEFAULT_OP(Cos, cos)
|
||||
DEFAULT_OP(Cosh, cosh)
|
||||
DEFAULT_OP(Erf, erf)
|
||||
DEFAULT_OP(ErfInv, erfinv)
|
||||
DEFAULT_OP(Exp, exp)
|
||||
DEFAULT_OP(Expm1, expm1)
|
||||
DEFAULT_OP(Floor, floor);
|
||||
DEFAULT_OP(Log, log);
|
||||
DEFAULT_OP(Log2, log2);
|
||||
DEFAULT_OP(Log10, log10);
|
||||
DEFAULT_OP(Log1p, log1p);
|
||||
DEFAULT_OP(LogicalNot, operator!)
|
||||
DEFAULT_OP(Negative, operator-)
|
||||
DEFAULT_OP(Round, rint);
|
||||
DEFAULT_OP(Sin, sin)
|
||||
DEFAULT_OP(Sinh, sinh)
|
||||
DEFAULT_OP(Sqrt, sqrt)
|
||||
DEFAULT_OP(Rsqrt, rsqrt)
|
||||
DEFAULT_OP(Tan, tan)
|
||||
DEFAULT_OP(Tanh, tanh)
|
||||
|
||||
struct Imag {
|
||||
template <int N>
|
||||
Simd<float, N> operator()(Simd<complex64_t, N> x) {
|
||||
return simd::imag(x);
|
||||
}
|
||||
SINGLE()
|
||||
};
|
||||
|
||||
struct Real {
|
||||
template <int N>
|
||||
Simd<float, N> operator()(Simd<complex64_t, N> x) {
|
||||
return simd::real(x);
|
||||
}
|
||||
SINGLE()
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <int N, typename T>
|
||||
Simd<T, N> operator()(Simd<T, N> x) {
|
||||
return 1.0f / (1.0f + simd::exp(-x));
|
||||
}
|
||||
SINGLE()
|
||||
};
|
||||
|
||||
struct Sign {
|
||||
template <int N, typename T>
|
||||
Simd<T, N> operator()(Simd<T, N> x) {
|
||||
auto z = Simd<T, N>{0};
|
||||
if constexpr (std::is_unsigned_v<T>) {
|
||||
return x != z;
|
||||
} else if constexpr (std::is_same_v<T, complex64_t>) {
|
||||
return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
|
||||
} else {
|
||||
return simd::select(
|
||||
x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
|
||||
}
|
||||
}
|
||||
SINGLE()
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <int N, typename T>
|
||||
Simd<T, N> operator()(Simd<T, N> x) {
|
||||
return x * x;
|
||||
}
|
||||
SINGLE()
|
||||
};
|
||||
|
||||
} // namespace mlx::core::detail
|
||||
Reference in New Issue
Block a user