// Copyright © 2024 Apple Inc. #pragma once #include #include #include #include "mlx/backend/cpu/simd/simd.h" namespace mlx::core::detail { using namespace mlx::core::simd; #define SINGLE() \ template \ T operator()(T x) { \ return (*this)(Simd(x)).value; \ } #define DEFAULT_OP(Op, op) \ struct Op { \ template \ Simd operator()(Simd x) { \ return simd::op(x); \ } \ SINGLE() \ }; DEFAULT_OP(Abs, abs) DEFAULT_OP(ArcCos, acos) DEFAULT_OP(ArcCosh, acosh) DEFAULT_OP(ArcSin, asin) DEFAULT_OP(ArcSinh, asinh) DEFAULT_OP(ArcTan, atan) DEFAULT_OP(ArcTanh, atanh) DEFAULT_OP(BitwiseInvert, operator~) DEFAULT_OP(Ceil, ceil) DEFAULT_OP(Conjugate, conj) DEFAULT_OP(Cos, cos) DEFAULT_OP(Cosh, cosh) DEFAULT_OP(Erf, erf) DEFAULT_OP(ErfInv, erfinv) DEFAULT_OP(Exp, exp) DEFAULT_OP(Expm1, expm1) DEFAULT_OP(Floor, floor); DEFAULT_OP(Log, log); DEFAULT_OP(Log2, log2); DEFAULT_OP(Log10, log10); DEFAULT_OP(Log1p, log1p); DEFAULT_OP(LogicalNot, operator!) DEFAULT_OP(Negative, operator-) DEFAULT_OP(Round, rint); DEFAULT_OP(Sin, sin) DEFAULT_OP(Sinh, sinh) DEFAULT_OP(Sqrt, sqrt) DEFAULT_OP(Rsqrt, rsqrt) DEFAULT_OP(Tan, tan) DEFAULT_OP(Tanh, tanh) struct Imag { template Simd operator()(Simd x) { return simd::imag(x); } SINGLE() }; struct Real { template Simd operator()(Simd x) { return simd::real(x); } SINGLE() }; struct Sigmoid { template Simd operator()(Simd x) { auto y = 1.0f / (1.0f + simd::exp(simd::abs(x))); return simd::select(x < Simd{0}, y, Simd{1} - y); } SINGLE() }; struct Sign { template Simd operator()(Simd x) { auto z = Simd{0}; auto o = Simd{1}; auto m = Simd{-1}; if constexpr (std::is_unsigned_v) { return simd::select(x == z, z, o); } else if constexpr (std::is_same_v) { return simd::select(x == z, x, Simd(x / simd::abs(x))); } else { return simd::select(x < z, m, simd::select(x > z, o, z)); } } SINGLE() }; struct Square { template Simd operator()(Simd x) { return x * x; } SINGLE() }; template Simd fp32_from_bits(Simd x) { return *(Simd*)(&x); } template Simd fp32_to_bits(Simd x) { return *(Simd*)(&x); } struct ToFP8 { template Simd operator()(Simd f) { uint32_t fp8_max = 543 << 21; auto denorm_mask = Simd(141 << 23); Simd f_bits; Simd f32 = f; f_bits = fp32_to_bits(f32); Simd result = 0u; auto sign = f_bits & 0x80000000; f_bits = f_bits ^ sign; auto f_bits_low = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); auto result_low = Simd(f_bits_low - denorm_mask); auto mant_odd = Simd((f_bits >> 20) & 1); auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF); f_bits_high = f_bits_high + Simd(mant_odd); auto result_high = Simd(f_bits_high >> 20); result = select(f_bits < (121 << 23), result_low, result_high); auto result_sat = Simd(0x7E); result = select(f_bits >= fp8_max, result_sat, result); return result | Simd(sign >> 24); } template uint8_t operator()(T x) { return (*this)(Simd(x)).value; } }; struct FromFP8 { template Simd operator()(Simd x) { auto w = Simd(x) << 24; auto sign = w & 0x80000000; auto nonsign = w & 0x7FFFFFFF; auto renorm_shift = clz(nonsign); renorm_shift = simd::select( renorm_shift > Simd{4}, renorm_shift - Simd{4}, Simd{0}); Simd inf_nan_mask = (Simd(nonsign + 0x01000000) >> 8) & 0x7F800000; auto zero_mask = Simd(nonsign - 1) >> 31; auto result = sign | ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) | inf_nan_mask) & ~zero_mask); return fp32_from_bits(result); } float operator()(uint8_t x) { return (*this)(Simd(x)).value; } }; } // namespace mlx::core::detail