From 1a4f4c5ea66d6be1f5568bba03170c1cd71f78d6 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 19 Feb 2024 06:12:53 -0800 Subject: [PATCH] Refactor CPU compile preamble (#708) * refactor cpu preamble * fix include order * fix some issues' * fixes for linux * try to fix includes * add back warning suppression * more linux fixes --- mlx/backend/accelerate/primitives.cpp | 5 +- mlx/backend/common/CMakeLists.txt | 31 + mlx/backend/common/binary.cpp | 131 +- mlx/backend/common/compiled.cpp | 10 +- mlx/backend/common/compiled_preamble.h | 1122 +----------------- mlx/backend/common/erf.h | 11 - mlx/backend/common/make_compiled_preamble.sh | 34 + mlx/backend/common/ops.h | 591 +++++++++ mlx/backend/common/primitives.cpp | 89 +- mlx/backend/common/unary.h | 53 - mlx/types/complex.h | 4 +- tests/ops_tests.cpp | 6 +- 12 files changed, 732 insertions(+), 1355 deletions(-) delete mode 100644 mlx/backend/common/erf.h create mode 100644 mlx/backend/common/make_compiled_preamble.sh create mode 100644 mlx/backend/common/ops.h diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 4cccd35ae..e147b5888 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -81,11 +81,8 @@ void Abs::eval_cpu(const std::vector& inputs, array& out) { } else if (in.dtype() == int32 && in.flags().contiguous) { set_unary_output_data(in, out); vDSP_vabsi(in.data(), 1, out.data(), 1, in.data_size()); - } else if (is_unsigned(in.dtype())) { - // No-op for unsigned types - out.copy_shared_buffer(in); } else { - unary(in, out, AbsOp()); + eval(inputs, out); } } diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index b25001f2c..38a9819e5 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -1,3 +1,33 @@ + +if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + set(CLANG TRUE) +endif() + +add_custom_command( + OUTPUT compiled_preamble.cpp + COMMAND /bin/bash + ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp + ${CMAKE_CXX_COMPILER} + ${CMAKE_SOURCE_DIR} + ${CLANG} + + DEPENDS make_compiled_preamble.sh + compiled_preamble.h + ${CMAKE_SOURCE_DIR}/mlx/types/half_types.h + ${CMAKE_SOURCE_DIR}/mlx/types/fp16.h + ${CMAKE_SOURCE_DIR}/mlx/types/bf16.h + ${CMAKE_SOURCE_DIR}/mlx/types/complex.h + ops.h +) + +add_custom_target( + cpu_compiled_preamble + DEPENDS compiled_preamble.cpp +) + +add_dependencies(mlx cpu_compiled_preamble) + target_sources( mlx PRIVATE @@ -19,4 +49,5 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp + ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ) diff --git a/mlx/backend/common/binary.cpp b/mlx/backend/common/binary.cpp index 855e8467b..ec7097797 100644 --- a/mlx/backend/common/binary.cpp +++ b/mlx/backend/common/binary.cpp @@ -7,6 +7,7 @@ #include "mlx/allocator.h" #include "mlx/backend/common/binary.h" #include "mlx/backend/common/binary_two.h" +#include "mlx/backend/common/ops.h" #include "mlx/primitives.h" #include "mlx/utils.h" @@ -73,7 +74,7 @@ void Add::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, [](auto x, auto y) { return x + y; }); + binary(a, b, out, detail::Add()); } void DivMod::eval( @@ -135,106 +136,56 @@ void Divide::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = 
inputs[1]; - binary(a, b, out, [](auto x, auto y) { return x / y; }); + binary(a, b, out, detail::Divide()); } -struct RemainderFn { - template - std::enable_if_t & !std::is_signed_v, T> operator()( - T numerator, - T denominator) { - return numerator % denominator; - } - - template - std::enable_if_t & std::is_signed_v, T> operator()( - T numerator, - T denominator) { - auto r = numerator % denominator; - if (r != 0 && (r < 0 != denominator < 0)) - r += denominator; - return r; - } - - template - std::enable_if_t, T> operator()( - T numerator, - T denominator) { - auto r = std::fmod(numerator, denominator); - if (r != 0 && (r < 0 != denominator < 0)) { - r += denominator; - } - return r; - } - - complex64_t operator()(complex64_t numerator, complex64_t denominator) { - return numerator % denominator; - } -}; - void Remainder::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, RemainderFn{}); + binary(a, b, out, detail::Remainder()); } void Equal::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (equal_nan_) { - comparison_op(inputs[0], inputs[1], out, [](auto x, auto y) { - return x == y || (std::isnan(x) && std::isnan(y)); - }); + comparison_op(inputs[0], inputs[1], out, detail::NaNEqual()); } else { - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x == y; }); + comparison_op(inputs[0], inputs[1], out, detail::Equal()); } } void Greater::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x > y; }); + comparison_op(inputs[0], inputs[1], out, detail::Greater()); } void GreaterEqual::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x >= y; }); + comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual()); } void Less::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x < y; }); + comparison_op(inputs[0], inputs[1], out, detail::Less()); } void LessEqual::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x <= y; }); + comparison_op(inputs[0], inputs[1], out, detail::LessEqual()); } void LogAddExp::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - auto op = [](auto x, auto y) { - constexpr float inf = std::numeric_limits::infinity(); - auto maxval = (x > y) ? x : y; - auto minval = (x > y) ? y : x; - return (minval == -inf || maxval == inf) - ? 
maxval - : static_cast( - maxval + std::log1p(std::exp(minval - maxval))); - }; if (is_floating_point(out.dtype())) { if (out.dtype() == float32) { - binary_op(a, b, out, op); + binary_op(a, b, out, detail::LogAddExp()); } else if (out.dtype() == float16) { - binary_op(a, b, out, op); + binary_op(a, b, out, detail::LogAddExp()); } else if (out.dtype() == bfloat16) { - binary_op(a, b, out, op); + binary_op(a, b, out, detail::LogAddExp()); } else { std::ostringstream err; err << "[logaddexp] Does not support " << out.dtype(); @@ -251,84 +202,40 @@ void Maximum::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - - if (is_floating_point(out.dtype())) { - binary(a, b, out, [](auto x, auto y) { - if (std::isnan(x)) { - return x; - } - return (x > y) ? x : y; - }); - } else { - binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; }); - } + binary(a, b, out, detail::Maximum()); } void Minimum::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - if (is_floating_point(out.dtype())) { - binary(a, b, out, [](auto x, auto y) { - if (std::isnan(x)) { - return x; - } - return (x < y) ? x : y; - }); - } else { - binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; }); - } + binary(a, b, out, detail::Minimum()); } void Multiply::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, [](auto x, auto y) { return x * y; }); + binary(a, b, out, detail::Multiply()); } void NotEqual::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); - comparison_op( - inputs[0], inputs[1], out, [](auto x, auto y) { return x != y; }); + comparison_op(inputs[0], inputs[1], out, detail::NotEqual()); } -struct PowerFn { - template - std::enable_if_t, T> operator()(T base, T exp) { - return std::pow(base, exp); - } - - template - std::enable_if_t, T> operator()(T base, T exp) { - if (exp < 0) { - throw std::invalid_argument( - "Integers cannot be raise to negative powers"); - } - T res = 1; - while (exp) { - if (exp & 1) { - res *= base; - } - exp >>= 1; - base *= base; - } - return res; - } -}; - void Power::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, PowerFn{}); + binary(a, b, out, detail::Power()); } void Subtract::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - binary(a, b, out, [](auto x, auto y) { return x - y; }); + binary(a, b, out, detail::Subtract()); } } // namespace mlx::core diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 914b85ae3..52bcac4fa 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -178,7 +178,13 @@ void* compile( build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared " << source_file_path << " -o " << shared_lib_path; std::string build_command_str = build_command.str(); - system(build_command_str.c_str()); + auto return_code = system(build_command_str.c_str()); + if (return_code) { + std::ostringstream msg; + msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name + << " with error code " << return_code << "." 
<< std::endl; + throw std::runtime_error(msg.str()); + } } // load library @@ -421,7 +427,7 @@ void Compiled::eval_cpu( // If it doesn't exist, compile it if (fn_ptr == nullptr) { std::ostringstream kernel; - kernel << preamble << std::endl; + kernel << get_kernel_preamble() << std::endl; kernel << "extern \"C\" {" << std::endl; build_kernel( kernel, diff --git a/mlx/backend/common/compiled_preamble.h b/mlx/backend/common/compiled_preamble.h index 8ccaa8bd7..84b77d29d 100644 --- a/mlx/backend/common/compiled_preamble.h +++ b/mlx/backend/common/compiled_preamble.h @@ -1,1121 +1,11 @@ -// Copyright © 2023-2024 Apple Inc. +// Copyright © 2023-24 Apple Inc. -const std::string preamble = R"( -#include -#include -#include - -#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC - -#include -typedef __fp16 float16_t; - -#else - -#define ADD_HALF_BINOPS -#include -#include -#include -#include - -#define __MLX_HALF_NAN__ 0x7D00 - - -namespace { -union float_bits_fp16 { - float f; - uint32_t u; -}; -} // namespace - -struct _MLX_Float16 { - uint16_t bits_; - - // Default constructor - _MLX_Float16() = default; - - // Default copy constructor - _MLX_Float16(_MLX_Float16 const&) = default; - - // Appease std::vector for being special - _MLX_Float16& operator=(std::vector::reference x) { - bits_ = x; - return *this; - } - - _MLX_Float16& operator=(const float& x) { - return (*this = _MLX_Float16(x)); - } - - // From float32 - _MLX_Float16(const float& x) : bits_(0) { - // Conversion following - // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h - - // Union - float_bits_fp16 in; - - // Take fp32 bits - in.f = x; - - // Find and take sign bit - uint32_t x_sign_32 = in.u & uint32_t(0x80000000); - uint16_t x_sign_16 = (x_sign_32 >> 16); - - if (std::isnan(x)) { - bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__); - } else { - // Union - float_bits_fp16 inf_scale, zero_scale, magic_bits; - - // Find exponent bits and take the max supported by half - uint32_t x_expo_32 = in.u & uint32_t(0x7f800000); - uint32_t max_expo_32 = uint32_t(0x38800000); - x_expo_32 = x_expo_32 < max_expo_32 ? 
max_expo_32 : x_expo_32; - x_expo_32 += uint32_t(15) << 23; - - // Handle scaling to inf as needed - inf_scale.u = uint32_t(0x77800000); - zero_scale.u = uint32_t(0x08800000); - - // Combine with magic and let addition do rounding - magic_bits.u = x_expo_32; - magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f; - - // Take the lower 5 bits of the exponent - uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00)); - - // Collect the lower 12 bits which have the mantissa - uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff); - - // Combine sign, exp and mantissa - bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16)); - } - } - - // To float32 - operator float() const { - // Conversion following - // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h - - // Union - float_bits_fp16 out; - - uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000); - uint32_t base = (bits_ << 16); - uint32_t two_base = base + base; - - uint32_t denorm_max = 1u << 27; - if (two_base < denorm_max) { - out.u = uint32_t(126) << 23; // magic mask - out.u |= (two_base >> 17); // Bits from fp16 - out.f -= 0.5f; // magic bias - } else { - out.u = uint32_t(0xE0) << 23; // exponent offset - out.u += (two_base >> 4); // Bits from fp16 - float out_unscaled = out.f; // Store value - out.u = uint32_t(0x7800000); // exponent scale - out.f *= out_unscaled; - } - - // Add sign - out.u |= x_sign_32; - - return out.f; - } -}; - -#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ - inline otype __operator__(atype lhs, btype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } - -#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \ - inline otype __operator__(_MLX_Float16 lhs, itype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } \ - inline otype __operator__(itype lhs, _MLX_Float16 rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } - -// Operators -#define half_binop(__op__, __operator__) \ - half_binop_base( \ - __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \ - half_binop_helper(__op__, __operator__, float, float, float); \ - half_binop_helper(__op__, __operator__, double, double, double); \ - half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \ - half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \ - half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \ - half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \ - half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float); - -half_binop(+, operator+); -half_binop(-, operator-); -half_binop(*, operator*); -half_binop(/, operator/); - -#undef half_binop - -// Comparison ops -#define half_compop(__op__, __operator__) \ - half_binop_base( \ - __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \ - half_binop_helper(__op__, __operator__, bool, float, float); \ - half_binop_helper(__op__, __operator__, bool, double, double); \ - half_binop_helper(__op__, __operator__, bool, int32_t, float); \ - half_binop_helper(__op__, __operator__, bool, uint32_t, float); \ - half_binop_helper(__op__, __operator__, bool, int64_t, float); \ - half_binop_helper(__op__, __operator__, bool, uint64_t, float); - -half_compop(>, operator>); -half_compop(<, operator<); -half_compop(>=, operator>=); -half_compop(<=, operator<=); -half_compop(==, operator==); -half_compop(!=, operator!=); - -#undef half_compop - -// Negative -inline _MLX_Float16 operator-(_MLX_Float16 
lhs) { - return -static_cast(lhs); -} - -// Inplace ops -#define half_inplace_op(__op__, __operator__) \ - inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \ - lhs = lhs __op__ rhs; \ - return lhs; \ - } \ - inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \ - lhs = lhs __op__ rhs; \ - return lhs; \ - } - -half_inplace_op(+, operator+=); -half_inplace_op(-, operator-=); -half_inplace_op(*, operator*=); -half_inplace_op(/, operator/=); - -#undef half_inplace_op - -// Bitwise ops - -#define half_bitop(__op__, __operator__) \ - inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \ - _MLX_Float16 out; \ - out.bits_ = lhs.bits_ __op__ rhs.bits_; \ - return out; \ - } \ - inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \ - _MLX_Float16 out; \ - out.bits_ = lhs.bits_ __op__ rhs; \ - return out; \ - } \ - inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \ - _MLX_Float16 out; \ - out.bits_ = lhs __op__ rhs.bits_; \ - return out; \ - } - -half_bitop(|, operator|); -half_bitop(&, operator&); -half_bitop(^, operator^); - -#undef half_bitop - -#define half_inplace_bitop(__op__, __operator__) \ - inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \ - lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ - return lhs; \ - } \ - inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \ - lhs.bits_ = lhs.bits_ __op__ rhs; \ - return lhs; \ - } - -half_inplace_bitop(|, operator|=); -half_inplace_bitop(&, operator&=); -half_inplace_bitop(^, operator^=); - -#undef half_inplace_bitop - -typedef struct _MLX_Float16 float16_t; - -#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC -#ifdef __ARM_FEATURE_BF16 - -#include -typedef __bf16 bfloat16_t; - -#else - -#define ADD_HALF_BINOPS -#include -#include -#include -#include - -#define __MLX_BFLOAT_NAN__ 0x7FC0 - - -namespace { -union float_bits_bf16 { - float f; - uint32_t u; -}; -} // namespace - -struct _MLX_BFloat16 { - uint16_t bits_; - - // Default constructor - _MLX_BFloat16() = default; - - // Default copy constructor - _MLX_BFloat16(_MLX_BFloat16 const&) = default; - - // Appease std::vector for being special - _MLX_BFloat16& operator=(std::vector::reference x) { - bits_ = x; - return *this; - } - - _MLX_BFloat16& operator=(const float& x) { - return (*this = _MLX_BFloat16(x)); - } - - // From float32 - _MLX_BFloat16(const float& x) { - if (std::isnan(x)) { - bits_ = __MLX_BFLOAT_NAN__; - } else { - // Union - float_bits_bf16 in; - - // Take bits - in.f = x; - - // Round to nearest even - in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF); - - // Take upper 16 bits - bits_ = in.u >> 16; - } - } - - // To float32 - operator float() const { - // Union - float_bits_bf16 out; - - // Upper 16 bits are the data and lower 16 bits are 0s - out.u = ((uint32_t)bits_) << 16; - - return out.f; - } -}; - -#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ - inline otype __operator__(atype lhs, btype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } - -#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ - inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } \ - inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } - -// Operators -#define bfloat_binop(_op_, _operator_) \ - bfloat_binop_base( \ - _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ - 
bfloat_binop_helper(_op_, _operator_, float, float, float); \ - bfloat_binop_helper(_op_, _operator_, double, double, double); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); - -bfloat_binop(+, operator+); -bfloat_binop(-, operator-); -bfloat_binop(*, operator*); -bfloat_binop(/, operator/); - -#undef bfloat_binop - -// Comparison ops -#define bfloat_compop(__op__, __operator__) \ - bfloat_binop_base( \ - __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ - bfloat_binop_helper(__op__, __operator__, bool, float, float); \ - bfloat_binop_helper(__op__, __operator__, bool, double, double); \ - bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); - -bfloat_compop(>, operator>); -bfloat_compop(<, operator<); -bfloat_compop(>=, operator>=); -bfloat_compop(<=, operator<=); -bfloat_compop(==, operator==); -bfloat_compop(!=, operator!=); - -#undef bfloat_compop - -// Negative -inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) { - return -static_cast(lhs); -} - -// Inplace ops -#define bfloat_inplace_op(__op__, __operator__) \ - inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \ - lhs = lhs __op__ rhs; \ - return lhs; \ - } \ - inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) { \ - lhs = lhs __op__ rhs; \ - return lhs; \ - } - -bfloat_inplace_op(+, operator+=); -bfloat_inplace_op(-, operator-=); -bfloat_inplace_op(*, operator*=); -bfloat_inplace_op(/, operator/=); - -#undef bfloat_inplace_op - -// Bitwise ops - -#define bfloat_bitop(__op__, __operator__) \ - inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \ - _MLX_BFloat16 out; \ - out.bits_ = lhs.bits_ __op__ rhs.bits_; \ - return out; \ - } \ - inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) { \ - _MLX_BFloat16 out; \ - out.bits_ = lhs.bits_ __op__ rhs; \ - return out; \ - } \ - inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) { \ - _MLX_BFloat16 out; \ - out.bits_ = lhs __op__ rhs.bits_; \ - return out; \ - } - -bfloat_bitop(|, operator|); -bfloat_bitop(&, operator&); -bfloat_bitop(^, operator^); - -#undef bfloat_bitop - -#define bfloat_inplace_bitop(__op__, __operator__) \ - inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ - lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ - return lhs; \ - } \ - inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) { \ - lhs.bits_ = lhs.bits_ __op__ rhs; \ - return lhs; \ - } - -bfloat_inplace_bitop(|, operator|=); -bfloat_inplace_bitop(&, operator&=); -bfloat_inplace_bitop(^, operator^=); - -#undef bfloat_inplace_bitop - -typedef struct _MLX_BFloat16 bfloat16_t; - -#endif // __ARM_FEATURE_BF16 - -#ifdef ADD_HALF_BINOPS +#pragma once // clang-format off -#define fp16_bf16_binop_helper(__op__, __operator__) \ - inline float __operator__(float16_t lhs, bfloat16_t rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } \ - inline float __operator__(bfloat16_t lhs, float16_t rhs) { \ - return static_cast(lhs) __op__ 
static_cast(rhs); \ - } - -fp16_bf16_binop_helper(+, operator+) -fp16_bf16_binop_helper(-, operator-) -fp16_bf16_binop_helper(*, operator*) -fp16_bf16_binop_helper(/, operator/) +#include "mlx/types/half_types.h" +#include "mlx/types/complex.h" +#include "mlx/backend/common/ops.h" // clang-format on -#endif - - -struct complex64_t; - -template -static constexpr bool can_convert_to_complex64 = - !std::is_same_v && std::is_convertible_v; - -struct complex64_t : public std::complex { - complex64_t(float v, float u) : std::complex(v, u){}; - complex64_t(std::complex v) : std::complex(v){}; - - template < - typename T, - typename = typename std::enable_if>::type> - complex64_t(T x) : std::complex(x){}; - - operator float() const { - return real(); - }; -}; - -inline bool operator>=(const complex64_t& a, const complex64_t& b) { - return (a.real() > b.real()) || - (a.real() == b.real() && a.imag() >= b.imag()); -} - -inline bool operator>(const complex64_t& a, const complex64_t& b) { - return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag()); -} - -inline complex64_t operator%(complex64_t a, complex64_t b) { - auto real = a.real() - (b.real() * static_cast(a.real() / b.real())); - auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag())); - if (real != 0 && (real < 0 != b.real() < 0)) - real += b.real(); - if (imag != 0 && (imag < 0 != b.imag() < 0)) - imag += b.imag(); - return {real, imag}; -} - -inline bool operator<=(const complex64_t& a, const complex64_t& b) { - return operator>=(b, a); -} - -inline bool operator<(const complex64_t& a, const complex64_t& b) { - return operator>(b, a); -} - -inline complex64_t operator-(const complex64_t& v) { - return -static_cast>(v); -} - -// clang-format off -#define complex_binop_helper(_op_, _operator_, itype) \ - inline complex64_t _operator_(itype x, const complex64_t& y) { \ - return static_cast(x) _op_ y; \ - } \ - inline complex64_t _operator_(const complex64_t& x, itype y) { \ - return x _op_ static_cast(y); \ - } - -#define complex_binop(_op_, _operator_) \ - inline complex64_t _operator_(const std::complex& x, const complex64_t& y) { \ - return x _op_ static_cast>(y); \ - } \ - inline complex64_t _operator_(const complex64_t& x, const std::complex& y) { \ - return static_cast>(x) _op_ y; \ - } \ - inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \ - return static_cast>(x) \ - _op_ static_cast>(y); \ - } \ - complex_binop_helper(_op_, _operator_, bool) \ - complex_binop_helper(_op_, _operator_, uint32_t) \ - complex_binop_helper(_op_, _operator_, uint64_t) \ - complex_binop_helper(_op_, _operator_, int32_t) \ - complex_binop_helper(_op_, _operator_, int64_t) \ - complex_binop_helper(_op_, _operator_, float16_t) \ - complex_binop_helper(_op_, _operator_, bfloat16_t) \ - complex_binop_helper(_op_, _operator_, float) -// clang-format on - -complex_binop(+, operator+) - -typedef union { - int i; - float f; -} IntOrFloat; - -inline float fast_exp(float x) { - x *= 1.442695; // multiply with log_2(e) - float ipart, fpart; - IntOrFloat epart; - x = std::max(-80.f, std::min(x, 80.f)); - ipart = std::floor(x + 0.5); - fpart = x - ipart; - - x = 1.535336188319500e-4f; - x = x * fpart + 1.339887440266574e-3f; - x = x * fpart + 9.618437357674640e-3f; - x = x * fpart + 5.550332471162809e-2f; - x = x * fpart + 2.402264791363012e-1f; - x = x * fpart + 6.931472028550421e-1f; - x = x * fpart + 1.000000000000000f; - - // generate 2**ipart in the floating point representation using integer - // bitshifting 
- epart.i = (int(ipart) + 127) << 23; - - return epart.f * x; -} - -float fast_erf(float a) { - float r, s, t, u; - t = std::abs(a); - s = a * a; - if (t > 0.927734375f) { - // maximum error 0.99527 ulp - r = std::fma( - -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 - u = std::fma( - -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 - r = std::fma(r, s, u); - r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 - r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 - r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 - r = std::fma(r, t, -t); - // TODO, replace with expm1 when implemented - r = 1.0f - std::exp(r); - r = std::copysign(r, a); - } else { - // maximum error 0.98929 ulp - r = -5.96761703e-4f; // -0x1.38e000p-11 - r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 - r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 - r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 - r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2 - r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 - r = std::fma(r, a, a); - } - return r; -} - -float fast_erfinv(float a) { - auto t = std::fma(a, 0.0f - a, 1.0f); - t = std::log(t); - float p; - if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793 - p = 3.03697567e-10f; // 0x1.4deb44p-32 - p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 - p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 - p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 - p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 - p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 - p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 - p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 - p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 - } else { // maximum ulp error = 2.35002 - p = 5.43877832e-9f; // 0x1.75c000p-28 - p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 - p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 - p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 - p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 - p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 - p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 - p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 - p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 - p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 - } - return a * p; -} - -struct Abs { - template - T operator()(T x) { - return std::abs(x); - }; - uint8_t operator()(uint8_t x) { - return x; - }; - uint16_t operator()(uint16_t x) { - return x; - }; - uint32_t operator()(uint32_t x) { - return x; - }; - uint64_t operator()(uint64_t x) { - return x; - }; - bool operator()(bool x) { - return x; - }; -}; - -struct ArcCos { - template - T operator()(T x) { - return std::acos(x); - }; -}; - -struct ArcCosh { - template - T operator()(T x) { - return std::acosh(x); - }; -}; - -struct ArcSin { - template - T operator()(T x) { - return std::asin(x); - }; -}; - -struct ArcSinh { - template - T operator()(T x) { - return std::asinh(x); - }; -}; - -struct ArcTan { - template - T operator()(T x) { - return std::atan(x); - }; -}; - -struct ArcTanh { - template - T operator()(T x) { - return std::atanh(x); - }; -}; - -struct Ceil { - template - T operator()(T x) { - return std::ceil(x); - }; - int8_t operator()(int8_t x) { - return x; - }; - int16_t operator()(int16_t x) { - return x; - }; - int32_t operator()(int32_t x) { - return x; - }; - int64_t operator()(int64_t x) { - return x; - }; - uint8_t operator()(uint8_t x) { - return x; - }; - 
uint16_t operator()(uint16_t x) { - return x; - }; - uint32_t operator()(uint32_t x) { - return x; - }; - uint64_t operator()(uint64_t x) { - return x; - }; - bool operator()(bool x) { - return x; - }; -}; - -struct Cos { - template - T operator()(T x) { - return std::cos(x); - }; -}; - -struct Cosh { - template - T operator()(T x) { - return std::cosh(x); - }; -}; - -struct Erf { - template - T operator()(T x) { - return static_cast(fast_erf(static_cast(x))); - }; -}; - -struct ErfInv { - template - T operator()(T x) { - return static_cast(fast_erfinv(static_cast(x))); - }; -}; - -struct Exp { - template - T operator()(T x) { - return fast_exp(x); - }; -}; - -struct Floor { - template - T operator()(T x) { - return std::floor(x); - }; - int8_t operator()(int8_t x) { - return x; - }; - int16_t operator()(int16_t x) { - return x; - }; - int32_t operator()(int32_t x) { - return x; - }; - int64_t operator()(int64_t x) { - return x; - }; - uint8_t operator()(uint8_t x) { - return x; - }; - uint16_t operator()(uint16_t x) { - return x; - }; - uint32_t operator()(uint32_t x) { - return x; - }; - uint64_t operator()(uint64_t x) { - return x; - }; - bool operator()(bool x) { - return x; - }; -}; - -struct Log { - template - T operator()(T x) { - return std::log(x); - }; -}; - -struct Log2 { - template - T operator()(T x) { - return std::log2(x); - }; -}; - -struct Log10 { - template - T operator()(T x) { - return std::log10(x); - }; -}; - -struct Log1p { - template - T operator()(T x) { - return log1p(x); - }; -}; - -struct LogicalNot { - template - T operator()(T x) { - return !x; - }; -}; - -struct Negative { - template - T operator()(T x) { - return -x; - }; -}; - -struct Round { - template - T operator()(T x) { - return std::rint(x); - } - - std::complex operator()(std::complex x) { - return {std::rint(x.real()), std::rint(x.imag())}; - } -}; - -struct Sigmoid { - template - T operator()(T x) { - auto one = static_cast(1.0); - return one / (one + fast_exp(-x)); - } -}; - -struct Sign { - template - T operator()(T x) { - return (x > T(0)) - (x < T(0)); - } - uint8_t operator()(uint8_t x) { - return x != 0; - } - uint16_t operator()(uint16_t x) { - return x != 0; - } - uint32_t operator()(uint32_t x) { - return x != 0; - } - uint64_t operator()(uint64_t x) { - return x != 0; - } -}; - -struct Sin { - template - T operator()(T x) { - return std::sin(x); - }; -}; - -struct Sinh { - template - T operator()(T x) { - return std::sinh(x); - }; -}; - -struct Square { - template - T operator()(T x) { - return x * x; - }; -}; - -struct Sqrt { - template - T operator()(T x) { - return std::sqrt(x); - }; -}; - -struct Rsqrt { - template - T operator()(T x) { - return static_cast(1.0) / std::sqrt(x); - }; -}; - -struct Tan { - template - T operator()(T x) { - return std::tan(x); - }; -}; - -struct Tanh { - template - T operator()(T x) { - return std::tanh(x); - }; -}; - -struct Add { - template - T operator()(T x, T y) { - return x + y; - } -}; - -struct Divide { - template - T operator()(T x, T y) { - return x / y; - } -}; - -struct Remainder { - template - std::enable_if_t & !std::is_signed_v, T> operator()( - T numerator, - T denominator) { - return numerator % denominator; - } - - template - std::enable_if_t & std::is_signed_v, T> operator()( - T numerator, - T denominator) { - auto r = numerator % denominator; - if (r != 0 && (r < 0 != denominator < 0)) - r += denominator; - return r; - } - - template - std::enable_if_t, T> operator()( - T numerator, - T denominator) { - auto r = std::fmod(numerator, 
denominator); - if (r != 0 && (r < 0 != denominator < 0)) { - r += denominator; - } - return r; - } - - std::complex operator()( - std::complex a, std::complex b) { - auto real = a.real() - (b.real() * static_cast(a.real() / b.real())); - auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag())); - if (real != 0 && ((real < 0) != (b.real() < 0))) - real += b.real(); - if (imag != 0 && ((imag < 0) != (b.imag() < 0))) - imag += b.imag(); - return {real, imag}; - } -}; - -struct Equal { - template - bool operator()(T x, T y) { - return x == y; - } -}; - -struct NaNEqual { - template - bool operator()(T x, T y) { - return x == y || (std::isnan(x) && std::isnan(y)); - } -}; - -struct Greater { - template - bool operator()(T x, T y) { - return x > y; - } -}; - -struct GreaterEqual { - template - bool operator()(T x, T y) { - return x >= y; - } -}; - -struct Less { - template - bool operator()(T x, T y) { - return x < y; - } -}; - -struct LessEqual { - template - bool operator()(T x, T y) { - return x <= y; - } -}; - -struct LogAddExp { - template - T operator()(T x, T y) { - constexpr float inf = std::numeric_limits::infinity(); - auto maxval = (x > y) ? x : y; - auto minval = (x > y) ? y : x; - return (minval == -inf || maxval == inf) - ? maxval - : static_cast( - maxval + std::log1p(fast_exp(minval - maxval))); - }; -}; - -struct Maximum { - template - std::enable_if_t, T> operator()(T x, T y) { - return (x > y) ? x : y; - } - - template - std::enable_if_t, T> operator()(T x, T y) { - if (std::isnan(x)) { - return x; - } - return (x > y) ? x : y; - } -}; - -struct Minimum { - template - std::enable_if_t, T> operator()(T x, T y) { - return x < y ? x : y; - } - - template - std::enable_if_t, T> operator()(T x, T y) { - if (std::isnan(x)) { - return x; - } - return x < y ? x : y; - } -}; - -struct Multiply { - template - T operator()(T x, T y) { - return x * y; - } -}; - -struct NotEqual { - template - bool operator()(T x, T y) { - return x != y; - } -}; - -struct Power { - template - std::enable_if_t, T> operator()(T base, T exp) { - return std::pow(base, exp); - } - - template - std::enable_if_t, T> operator()(T base, T exp) { - T res = 1; - while (exp) { - if (exp & 1) { - res *= base; - } - exp >>= 1; - base *= base; - } - return res; - } -}; - -struct Subtract { - template - T operator()(T x, T y) { - return x - y; - } -}; - -struct LogicalAnd { - template - T operator()(T x, T y) { - return x && y; - }; -}; - -struct LogicalOr { - template - T operator()(T x, T y) { - return x || y; - }; -}; -)"; +const char* get_kernel_preamble(); diff --git a/mlx/backend/common/erf.h b/mlx/backend/common/erf.h deleted file mode 100644 index a175a0c43..000000000 --- a/mlx/backend/common/erf.h +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright © 2023 Apple Inc. - -namespace mlx::core { - -/* Approximation to the inverse error function. - * Based on code from: - * https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348 - */ -float erfinv(float a); - -} // namespace mlx::core diff --git a/mlx/backend/common/make_compiled_preamble.sh b/mlx/backend/common/make_compiled_preamble.sh new file mode 100644 index 000000000..687f4cfc7 --- /dev/null +++ b/mlx/backend/common/make_compiled_preamble.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# +# This script generates a C++ function that provides the CPU +# code for use with kernel generation. +# +# Copyright © 2023-24 Apple Inc. 
+ + +OUTPUT_FILE=$1 +GCC=$2 +SRCDIR=$3 +CLANG=$4 + +if [ $CLANG = "TRUE" ]; then + read -r -d '' INCLUDES <<- EOM + #include + #include + #include + #include +EOM + +fi + +CONTENT=$($GCC -I $SRCDIR -E $SRCDIR/mlx/backend/common/compiled_preamble.h 2>/dev/null) + +cat << EOF > "$OUTPUT_FILE" +const char* get_kernel_preamble() { +return R"preamble( +$INCLUDES +$CONTENT +using namespace mlx::core::detail; +)preamble"; +} +EOF diff --git a/mlx/backend/common/ops.h b/mlx/backend/common/ops.h new file mode 100644 index 000000000..8b2d7ab58 --- /dev/null +++ b/mlx/backend/common/ops.h @@ -0,0 +1,591 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once +#include +#include +#include + +namespace mlx::core::detail { + +typedef union { + int i; + float f; +} IntOrFloat; + +inline float fast_exp(float x) { + if (x == -std::numeric_limits::infinity()) { + return 0.0f; + } else if (x == std::numeric_limits::infinity() || std::isnan(x)) { + return x; + } + x *= 1.442695; // multiply with log_2(e) + float ipart, fpart; + IntOrFloat epart; + x = std::max(-80.f, std::min(x, 80.f)); + ipart = std::floor(x + 0.5); + fpart = x - ipart; + + x = 1.535336188319500e-4f; + x = x * fpart + 1.339887440266574e-3f; + x = x * fpart + 9.618437357674640e-3f; + x = x * fpart + 5.550332471162809e-2f; + x = x * fpart + 2.402264791363012e-1f; + x = x * fpart + 6.931472028550421e-1f; + x = x * fpart + 1.000000000000000f; + + // generate 2**ipart in the floating point representation using integer + // bitshifting + epart.i = (int(ipart) + 127) << 23; + + return epart.f * x; +} + +inline float fast_erf(float a) { + float r, s, t, u; + t = std::abs(a); + s = a * a; + if (t > 0.927734375f) { + // maximum error 0.99527 ulp + r = std::fma( + -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 + u = std::fma( + -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 + r = std::fma(r, s, u); + r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 + r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 + r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 + r = std::fma(r, t, -t); + // TODO, replace with expm1 when implemented + r = 1.0f - std::exp(r); + r = std::copysign(r, a); + } else { + // maximum error 0.98929 ulp + r = -5.96761703e-4f; // -0x1.38e000p-11 + r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 + r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 + r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 + r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2 + r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 + r = std::fma(r, a, a); + } + return r; +} + +inline float fast_erfinv(float a) { + auto t = std::fma(a, 0.0f - a, 1.0f); + t = std::log(t); + float p; + if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793 + p = 3.03697567e-10f; // 0x1.4deb44p-32 + p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 + p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 + p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 + p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 + p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 + p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 + p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 + p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 + } else { // maximum ulp error = 2.35002 + p = 5.43877832e-9f; // 0x1.75c000p-28 + p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 + p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 + p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 + p = std::fma(p, t, 
-5.61530760e-5f); // -0x1.d70bd0p-15 + p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 + p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 + p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 + p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 + p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 + } + return a * p; +} + +struct Abs { + template + T operator()(T x) { + return std::abs(x); + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct ArcCos { + template + T operator()(T x) { + return std::acos(x); + }; +}; + +struct ArcCosh { + template + T operator()(T x) { + return std::acosh(x); + }; +}; + +struct ArcSin { + template + T operator()(T x) { + return std::asin(x); + }; +}; + +struct ArcSinh { + template + T operator()(T x) { + return std::asinh(x); + }; +}; + +struct ArcTan { + template + T operator()(T x) { + return std::atan(x); + }; +}; + +struct ArcTanh { + template + T operator()(T x) { + return std::atanh(x); + }; +}; + +struct Ceil { + template + T operator()(T x) { + return std::ceil(x); + }; + int8_t operator()(int8_t x) { + return x; + }; + int16_t operator()(int16_t x) { + return x; + }; + int32_t operator()(int32_t x) { + return x; + }; + int64_t operator()(int64_t x) { + return x; + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct Cos { + template + T operator()(T x) { + return std::cos(x); + }; +}; + +struct Cosh { + template + T operator()(T x) { + return std::cosh(x); + }; +}; + +struct Erf { + template + T operator()(T x) { + return static_cast(fast_erf(static_cast(x))); + }; +}; + +struct ErfInv { + template + T operator()(T x) { + return static_cast(fast_erfinv(static_cast(x))); + }; +}; + +struct Exp { + template + T operator()(T x) { + return fast_exp(x); + }; + + complex64_t operator()(complex64_t x) { + return std::exp(x); + } +}; + +struct Floor { + template + T operator()(T x) { + return std::floor(x); + }; + int8_t operator()(int8_t x) { + return x; + }; + int16_t operator()(int16_t x) { + return x; + }; + int32_t operator()(int32_t x) { + return x; + }; + int64_t operator()(int64_t x) { + return x; + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct Log { + template + T operator()(T x) { + return std::log(x); + }; +}; + +struct Log2 { + template + T operator()(T x) { + return std::log2(x); + }; +}; + +struct Log10 { + template + T operator()(T x) { + return std::log10(x); + }; +}; + +struct Log1p { + template + T operator()(T x) { + return log1p(x); + }; +}; + +struct LogicalNot { + template + T operator()(T x) { + return !x; + }; +}; + +struct Negative { + template + T operator()(T x) { + return -x; + }; +}; + +struct Round { + template + T operator()(T x) { + return std::rint(x); + } + + complex64_t operator()(complex64_t x) { + return {std::rint(x.real()), std::rint(x.imag())}; + } +}; + +struct Sigmoid { + template + T operator()(T x) { + auto one = static_cast(1.0); + return one 
/ (one + fast_exp(-x)); + } +}; + +struct Sign { + template + T operator()(T x) { + return (x > T(0)) - (x < T(0)); + } + uint8_t operator()(uint8_t x) { + return x != 0; + } + uint16_t operator()(uint16_t x) { + return x != 0; + } + uint32_t operator()(uint32_t x) { + return x != 0; + } + uint64_t operator()(uint64_t x) { + return x != 0; + } +}; + +struct Sin { + template + T operator()(T x) { + return std::sin(x); + }; +}; + +struct Sinh { + template + T operator()(T x) { + return std::sinh(x); + }; +}; + +struct Square { + template + T operator()(T x) { + return x * x; + }; +}; + +struct Sqrt { + template + T operator()(T x) { + return std::sqrt(x); + }; +}; + +struct Rsqrt { + template + T operator()(T x) { + return static_cast(1.0) / std::sqrt(x); + }; +}; + +struct Tan { + template + T operator()(T x) { + return std::tan(x); + }; +}; + +struct Tanh { + template + T operator()(T x) { + return std::tanh(x); + }; +}; + +struct Add { + template + T operator()(T x, T y) { + return x + y; + } +}; + +struct Divide { + template + T operator()(T x, T y) { + return x / y; + } +}; + +struct Remainder { + template + std::enable_if_t & !std::is_signed_v, T> operator()( + T numerator, + T denominator) { + return numerator % denominator; + } + + template + std::enable_if_t & std::is_signed_v, T> operator()( + T numerator, + T denominator) { + auto r = numerator % denominator; + if (r != 0 && (r < 0 != denominator < 0)) + r += denominator; + return r; + } + + template + std::enable_if_t, T> operator()( + T numerator, + T denominator) { + auto r = std::fmod(numerator, denominator); + if (r != 0 && (r < 0 != denominator < 0)) { + r += denominator; + } + return r; + } + + complex64_t operator()(complex64_t numerator, complex64_t denominator) { + return numerator % denominator; + } +}; + +struct Equal { + template + bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + bool operator()(T x, T y) { + return x == y || (std::isnan(x) && std::isnan(y)); + } +}; + +struct Greater { + template + bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + bool operator()(T x, T y) { + return x < y; + } +}; + +struct LessEqual { + template + bool operator()(T x, T y) { + return x <= y; + } +}; + +struct Maximum { + template + std::enable_if_t, T> operator()(T x, T y) { + return (x > y) ? x : y; + } + + template + std::enable_if_t, T> operator()(T x, T y) { + if (std::isnan(x)) { + return x; + } + return (x > y) ? x : y; + } +}; + +struct Minimum { + template + std::enable_if_t, T> operator()(T x, T y) { + return x < y ? x : y; + } + + template + std::enable_if_t, T> operator()(T x, T y) { + if (std::isnan(x)) { + return x; + } + return x < y ? x : y; + } +}; + +struct LogAddExp { + template + T operator()(T x, T y) { + constexpr float inf = std::numeric_limits::infinity(); + auto maxval = Maximum()(x, y); + auto minval = Minimum()(x, y); + return (minval == -inf || maxval == inf) + ? 
maxval + : static_cast( + maxval + std::log1p(fast_exp(minval - maxval))); + }; +}; + +struct Multiply { + template + T operator()(T x, T y) { + return x * y; + } +}; + +struct NotEqual { + template + bool operator()(T x, T y) { + return x != y; + } +}; + +struct Power { + template + std::enable_if_t, T> operator()(T base, T exp) { + return std::pow(base, exp); + } + + template + std::enable_if_t, T> operator()(T base, T exp) { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } +}; + +struct Subtract { + template + T operator()(T x, T y) { + return x - y; + } +}; + +struct LogicalAnd { + template + T operator()(T x, T y) { + return x && y; + }; +}; + +struct LogicalOr { + template + T operator()(T x, T y) { + return x || y; + }; +}; + +} // namespace mlx::core::detail diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index 37c61761a..a1e99d7c7 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -10,7 +10,7 @@ #include "mlx/backend/common/arange.h" #include "mlx/backend/common/binary.h" #include "mlx/backend/common/copy.h" -#include "mlx/backend/common/erf.h" +#include "mlx/backend/common/ops.h" #include "mlx/backend/common/threefry.h" #include "mlx/backend/common/unary.h" #include "mlx/backend/common/utils.h" @@ -26,7 +26,7 @@ void Abs::eval(const std::vector& inputs, array& out) { // No-op for unsigned types out.copy_shared_buffer(in); } else { - unary(in, out, AbsOp()); + unary(in, out, detail::Abs()); } } @@ -38,7 +38,7 @@ void ArcCos::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::acos(x); }); + unary_fp(in, out, detail::ArcCos()); } else { throw std::invalid_argument( "[arccos] Cannot compute inverse cosine of elements in array" @@ -50,7 +50,7 @@ void ArcCosh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::acosh(x); }); + unary_fp(in, out, detail::ArcCosh()); } else { throw std::invalid_argument( "[arccosh] Cannot compute inverse hyperbolic cosine of elements in" @@ -62,7 +62,7 @@ void ArcSin::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::asin(x); }); + unary_fp(in, out, detail::ArcSin()); } else { throw std::invalid_argument( "[arcsin] Cannot compute inverse sine of elements in array" @@ -74,7 +74,7 @@ void ArcSinh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::asinh(x); }); + unary_fp(in, out, detail::ArcSinh()); } else { throw std::invalid_argument( "[arcsinh] Cannot compute inverse hyperbolic sine of elements in" @@ -86,7 +86,7 @@ void ArcTan::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::atan(x); }); + unary_fp(in, out, detail::ArcTan()); } else { throw std::invalid_argument( "[arctan] Cannot compute inverse tangent of elements in array" @@ -98,7 +98,7 @@ void ArcTanh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = 
inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::atanh(x); }); + unary_fp(in, out, detail::ArcTanh()); } else { throw std::invalid_argument( "[arctanh] Cannot compute inverse hyperbolic tangent of elements in" @@ -172,7 +172,7 @@ void Ceil::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (not is_integral(in.dtype())) { - unary_fp(in, out, [](auto x) { return std::ceil(x); }); + unary_fp(in, out, detail::Ceil()); } else { // No-op integer types out.copy_shared_buffer(in); @@ -212,7 +212,7 @@ void Cos::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::cos(x); }); + unary_fp(in, out, detail::Cos()); } else { throw std::invalid_argument( "[cos] Cannot compute cosine of elements in array" @@ -224,7 +224,7 @@ void Cosh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::cosh(x); }); + unary_fp(in, out, detail::Cosh()); } else { throw std::invalid_argument( "[cosh] Cannot compute hyperbolic cosine of elements in array" @@ -256,17 +256,13 @@ void Erf::eval(const std::vector& inputs, array& out) { const auto& in = inputs[0]; switch (out.dtype()) { case float32: - unary_op(in, out, [](auto x) { return std::erf(x); }); + unary_op(in, out, detail::Erf()); break; case float16: - unary_op(in, out, [](auto x) { - return static_cast(std::erf(static_cast(x))); - }); + unary_op(in, out, detail::Erf()); break; case bfloat16: - unary_op(in, out, [](auto x) { - return static_cast(std::erf(static_cast(x))); - }); + unary_op(in, out, detail::Erf()); break; default: throw std::invalid_argument( @@ -280,17 +276,13 @@ void ErfInv::eval(const std::vector& inputs, array& out) { const auto& in = inputs[0]; switch (out.dtype()) { case float32: - unary_op(in, out, [](auto x) { return erfinv(x); }); + unary_op(in, out, detail::ErfInv()); break; case float16: - unary_op(in, out, [](auto x) { - return static_cast(erfinv(static_cast(x))); - }); + unary_op(in, out, detail::ErfInv()); break; case bfloat16: - unary_op(in, out, [](auto x) { - return static_cast(erfinv(static_cast(x))); - }); + unary_op(in, out, detail::ErfInv()); break; default: throw std::invalid_argument( @@ -302,9 +294,8 @@ void ErfInv::eval(const std::vector& inputs, array& out) { void Exp::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; - if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::exp(x); }); + unary_fp(in, out, detail::Exp()); } else { throw std::invalid_argument( "[exp] Cannot exponentiate elements in array" @@ -316,7 +307,7 @@ void Floor::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (not is_integral(in.dtype())) { - unary_fp(in, out, [](auto x) { return std::floor(x); }); + unary_fp(in, out, detail::Floor()); } else { // No-op integer types out.copy_shared_buffer(in); @@ -344,13 +335,13 @@ void Log::eval(const std::vector& inputs, array& out) { if (is_floating_point(out.dtype())) { switch (base_) { case Base::e: - unary_fp(in, out, [](auto x) { return std::log(x); }); + unary_fp(in, out, detail::Log()); break; case Base::two: - unary_fp(in, out, [](auto x) { return std::log2(x); }); + unary_fp(in, out, detail::Log2()); break; case Base::ten: - 
unary_fp(in, out, [](auto x) { return std::log10(x); }); + unary_fp(in, out, detail::Log10()); break; } } else { @@ -364,7 +355,7 @@ void Log1p::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::log1p(x); }); + unary_fp(in, out, detail::Log1p()); } else { throw std::invalid_argument( "[log1p] Cannot compute log of elements in array with" @@ -375,27 +366,27 @@ void Log1p::eval(const std::vector& inputs, array& out) { void LogicalNot::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - unary(in, out, [](auto x) { return !x; }); + unary(in, out, detail::LogicalNot()); } void LogicalAnd::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); // LogicalAnd requires two input arrays auto& in1 = inputs[0]; auto& in2 = inputs[1]; - binary(in1, in2, out, [](auto x, auto y) { return x && y; }); + binary(in1, in2, out, detail::LogicalAnd()); } void LogicalOr::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); // LogicalOr requires two input arrays auto& in1 = inputs[0]; auto& in2 = inputs[1]; - binary(in1, in2, out, [](auto x, auto y) { return x || y; }); + binary(in1, in2, out, detail::LogicalOr()); } void Negative::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - unary(in, out, [](auto x) { return -x; }); + unary(in, out, detail::Negative()); } void Pad::eval(const std::vector& inputs, array& out) { @@ -498,7 +489,7 @@ void Round::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; if (not is_integral(in.dtype())) { - unary_fp(in, out, RoundOp()); + unary_fp(in, out, detail::Round()); } else { // No-op integer types out.copy_shared_buffer(in); @@ -509,11 +500,7 @@ void Sigmoid::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - auto sigmoid_op = [](auto x) { - auto one = static_cast(1.0); - return one / (one + std::exp(-x)); - }; - unary_fp(in, out, sigmoid_op); + unary_fp(in, out, detail::Sigmoid()); } else { throw std::invalid_argument( "[sigmoid] Cannot sigmoid of elements in array with" @@ -527,7 +514,7 @@ void Sign::eval(const std::vector& inputs, array& out) { if (in.dtype() == bool_) { out.copy_shared_buffer(in); } else { - unary(in, out, SignOp()); + unary(in, out, detail::Sign()); } } @@ -535,7 +522,7 @@ void Sin::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::sin(x); }); + unary_fp(in, out, detail::Sin()); } else { throw std::invalid_argument( "[sin] Cannot compute sine of elements in array" @@ -547,7 +534,7 @@ void Sinh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::sinh(x); }); + unary_fp(in, out, detail::Sinh()); } else { throw std::invalid_argument( "[sinh] Cannot compute hyperbolic sine of elements in array" @@ -656,18 +643,16 @@ void Split::eval( void Square::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - unary(in, out, [](auto x) { return x * x; }); + unary(in, out, detail::Square()); } void Sqrt::eval(const std::vector& inputs, array& out) { 
assert(inputs.size() == 1); auto& in = inputs[0]; if (recip_) { - unary_fp(in, out, [](auto x) { - return static_cast(1.0) / sqrt(x); - }); + unary_fp(in, out, detail::Rsqrt()); } else { - unary_fp(in, out, [](auto x) { return sqrt(x); }); + unary_fp(in, out, detail::Sqrt()); } } @@ -680,7 +665,7 @@ void Tan::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::tan(x); }); + unary_fp(in, out, detail::Tan()); } else { throw std::invalid_argument( "[tan] Cannot compute tangent of elements in array" @@ -692,7 +677,7 @@ void Tanh::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; if (is_floating_point(out.dtype())) { - unary_fp(in, out, [](auto x) { return std::tanh(x); }); + unary_fp(in, out, detail::Tanh()); } else { throw std::invalid_argument( "[tanh] Cannot compute hyperbolic tangent of elements in array" diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h index 7fdcbeb77..28c4f0f4a 100644 --- a/mlx/backend/common/unary.h +++ b/mlx/backend/common/unary.h @@ -11,59 +11,6 @@ namespace mlx::core { namespace { -struct AbsOp { - template - T operator()(T x) { - return std::abs(x); - } - uint8_t operator()(uint8_t x) { - return x; - } - uint16_t operator()(uint16_t x) { - return x; - } - uint32_t operator()(uint32_t x) { - return x; - } - uint64_t operator()(uint64_t x) { - return x; - } - bool operator()(bool x) { - return x; - } -}; - -struct SignOp { - template - T operator()(T x) { - return (x > T(0)) - (x < T(0)); - } - - uint8_t operator()(uint8_t x) { - return x != 0; - } - uint16_t operator()(uint16_t x) { - return x != 0; - } - uint32_t operator()(uint32_t x) { - return x != 0; - } - uint64_t operator()(uint64_t x) { - return x != 0; - } -}; - -struct RoundOp { - template - T operator()(T x) { - return std::rint(x); - } - - complex64_t operator()(complex64_t x) { - return {std::rint(x.real()), std::rint(x.imag())}; - } -}; - void set_unary_output_data(const array& in, array& out) { if (in.is_donatable() && in.itemsize() == out.itemsize()) { out.copy_shared_buffer(in); diff --git a/mlx/types/complex.h b/mlx/types/complex.h index 46f4310f9..f8a607766 100644 --- a/mlx/types/complex.h +++ b/mlx/types/complex.h @@ -38,9 +38,9 @@ inline bool operator>(const complex64_t& a, const complex64_t& b) { inline complex64_t operator%(complex64_t a, complex64_t b) { auto real = a.real() - (b.real() * static_cast(a.real() / b.real())); auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag())); - if (real != 0 && (real < 0 != b.real() < 0)) + if (real != 0 && ((real < 0) != (b.real() < 0))) real += b.real(); - if (imag != 0 && (imag < 0 != b.imag() < 0)) + if (imag != 0 && ((imag < 0) != (b.imag() < 0))) imag += b.imag(); return {real, imag}; } diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index e52c1294f..41db064be 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1002,7 +1002,7 @@ TEST_CASE("test arithmetic unary ops") { CHECK_EQ(exp(x).item(), 1.0); x = array(2.0); - CHECK_EQ(exp(x).item(), std::exp(2.0f)); + CHECK_EQ(exp(x).item(), doctest::Approx(std::exp(2.0f))); CHECK(array_equal(exp(array({})), array({})).item()); @@ -1012,7 +1012,7 @@ TEST_CASE("test arithmetic unary ops") { // Integer input type x = array(2); CHECK_EQ(x.dtype(), int32); - CHECK_EQ(exp(x).item(), std::exp(2.0f)); + CHECK_EQ(exp(x).item(), doctest::Approx(std::exp(2.0f))); // Input is irregularly 
strided
   x = broadcast_to(array(1.0f), {2, 2, 2});

@@ -1020,7 +1020,7 @@
   x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
   auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});
-  CHECK(array_equal(exp(x), expected).item<bool>());
+  CHECK(allclose(exp(x), expected).item<bool>());
 }

 // Test sine
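
The core of this refactor is that every elementwise op now lives as a plain functor in mlx::core::detail (ops.h), so the interpreted eval path in binary.cpp/primitives.cpp and the JIT-compiled kernels share one definition instead of duplicating lambdas and a long string preamble. A minimal sketch of that pattern follows; it is not the MLX dispatch code. The functors are trimmed copies of detail::Add and detail::LogAddExp (with std::exp in place of fast_exp), and apply_binary is a hypothetical stand-in for the library's binary_op driver.

    // Sketch: one functor type serves both an interpreted elementwise loop
    // and, because it is header-only, the text that gets pasted into the
    // JIT preamble. apply_binary is a hypothetical driver, not MLX API.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <limits>
    #include <vector>

    namespace detail {
    struct Add {
      template <typename T>
      T operator()(T x, T y) {
        return x + y;
      }
    };

    struct LogAddExp {
      template <typename T>
      T operator()(T x, T y) {
        constexpr float inf = std::numeric_limits<float>::infinity();
        auto maxval = std::max(x, y);
        auto minval = std::min(x, y);
        // Numerically stable log(exp(x) + exp(y)).
        return (minval == -inf || maxval == inf)
            ? maxval
            : static_cast<T>(maxval + std::log1p(std::exp(minval - maxval)));
      }
    };
    } // namespace detail

    // Hypothetical elementwise driver standing in for binary_op<T>(a, b, out, op).
    template <typename T, typename Op>
    void apply_binary(const std::vector<T>& a, const std::vector<T>& b,
                      std::vector<T>& out, Op op) {
      for (size_t i = 0; i < a.size(); ++i) {
        out[i] = op(a[i], b[i]);
      }
    }

    int main() {
      std::vector<float> a{0.f, 1.f, 2.f}, b{3.f, 4.f, 5.f}, out(3);
      apply_binary(a, b, out, detail::Add());       // elementwise sums
      apply_binary(a, b, out, detail::LogAddExp()); // stable log-sum-exp of pairs
      std::printf("%f %f %f\n", out[0], out[1], out[2]);
      return 0;
    }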
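
For reference, the new make_compiled_preamble.sh writes a compiled_preamble.cpp into the build tree whose only symbol is get_kernel_preamble(). Roughly, the generated file has the shape below; the comment lines stand in for the substituted $INCLUDES (extra standard headers when built with clang) and $CONTENT (the preprocessed compiled_preamble.h, which pulls in the half/bfloat/complex types and the detail:: functors from ops.h).

    // Abbreviated illustration of the generated compiled_preamble.cpp;
    // the comments mark where the script substitutes $INCLUDES and $CONTENT.
    const char* get_kernel_preamble() {
      return R"preamble(
    // ... extra standard-library includes when built with clang ...
    // ... preprocessed mlx/types/half_types.h, mlx/types/complex.h and
    //     mlx/backend/common/ops.h ...
    using namespace mlx::core::detail;
    )preamble";
    }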
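
The patch also makes the runtime compile step fail loudly instead of silently ignoring a bad g++ exit status. A self-contained sketch of that check, with placeholder paths and kernel name; the real code lives in mlx/backend/common/compiled.cpp.

    // Sketch: invoke g++ to build a shared library from a generated source
    // file and raise if the compiler exits non-zero.
    #include <cstdlib>
    #include <sstream>
    #include <stdexcept>
    #include <string>

    void compile_shared_lib(const std::string& source_file_path,
                            const std::string& shared_lib_path,
                            const std::string& kernel_name) {
      std::ostringstream build_command;
      build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
                    << source_file_path << " -o " << shared_lib_path;
      int return_code = std::system(build_command.str().c_str());
      if (return_code) {
        std::ostringstream msg;
        msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name
            << " with error code " << return_code << ".";
        throw std::runtime_error(msg.str());
      }
    }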