#pragma once #include #include #include #include #include #include #include "mlx/backend/cpu/simd/base_simd.h" // There seems to be a bug in simd/base_simd.h // __XROS_2_0 is not defined, the expression evaluates // to true instead of false setting the SIMD library // higher than it should be even on macOS < 15 #if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 || \ __IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \ __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \ __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \ __TV_OS_VERSION_MIN_REQUIRED >= 180000 #define MLX_SIMD_LIBRARY_VERSION 6 #else #define MLX_SIMD_LIBRARY_VERSION 5 #endif namespace mlx::core::simd { // Apple simd namespace namespace asd = ::simd; // This indirection is needed to remap certain types to ones that accelerate // SIMD can handle template struct ScalarT { using v = T; }; template struct ScalarT { using v = char; }; template struct ScalarT { using v = char; }; template struct ScalarT { using v = unsigned long; }; template struct ScalarT { using v = long; }; template struct Simd { static constexpr int size = N; using scalar_t = typename ScalarT::v; Simd() {} template Simd(Simd other) : value(asd::convert(other.value)) {} template Simd(U v) : value(v){}; Simd(Simd x, Simd y) { value = asd::make::packed_t>( x.value, y.value); }; T operator[](int idx) const { return reinterpret_cast(&value)[idx]; } T& operator[](int idx) { return reinterpret_cast(&value)[idx]; } typename asd::Vector::packed_t value; }; // Values chosen based on benchmarks on M3 Max // TODO: consider choosing these more optimally template <> inline constexpr int max_size = 16; template <> inline constexpr int max_size = 16; template <> inline constexpr int max_size = 8; template <> inline constexpr int max_size = 4; template <> inline constexpr int max_size = 16; template <> inline constexpr int max_size = 16; template <> inline constexpr int max_size = 8; template <> inline constexpr int max_size = 4; template <> inline constexpr int max_size = 8; template <> inline constexpr int max_size = 4; #define SIMD_DEFAULT_UNARY(name, op) \ template \ Simd name(Simd v) { \ return op(v.value); \ } SIMD_DEFAULT_UNARY(abs, asd::abs) SIMD_DEFAULT_UNARY(floor, asd::floor) SIMD_DEFAULT_UNARY(acos, asd::acos) SIMD_DEFAULT_UNARY(acosh, asd::acosh) SIMD_DEFAULT_UNARY(asin, asd::asin) SIMD_DEFAULT_UNARY(asinh, asd::asinh) SIMD_DEFAULT_UNARY(atan, asd::atan) SIMD_DEFAULT_UNARY(atanh, asd::atanh) SIMD_DEFAULT_UNARY(ceil, asd::ceil) SIMD_DEFAULT_UNARY(cosh, asd::cosh) SIMD_DEFAULT_UNARY(expm1, asd::expm1) SIMD_DEFAULT_UNARY(log, asd::log) SIMD_DEFAULT_UNARY(log2, asd::log2) SIMD_DEFAULT_UNARY(log10, asd::log10) SIMD_DEFAULT_UNARY(log1p, asd::log1p) SIMD_DEFAULT_UNARY(rint, asd::rint) SIMD_DEFAULT_UNARY(sinh, asd::sinh) SIMD_DEFAULT_UNARY(sqrt, asd::sqrt) SIMD_DEFAULT_UNARY(rsqrt, asd::rsqrt) SIMD_DEFAULT_UNARY(recip, asd::recip) SIMD_DEFAULT_UNARY(tan, asd::tan) SIMD_DEFAULT_UNARY(tanh, asd::tanh) template Simd operator-(Simd v) { return -v.value; } template Simd operator~(Simd v) { return ~v.value; } template Simd isnan(Simd v) { return asd::convert(v.value != v.value); } // No simd_boolN in accelerate, use int8_t instead template Simd operator!(Simd v) { return asd::convert(!v.value); } #define SIMD_DEFAULT_BINARY(OP) \ template \ Simd operator OP(Simd x, U y) { \ return asd::convert::scalar_t>(x.value OP y); \ } \ template \ Simd operator OP(T1 x, Simd y) { \ return asd::convert::scalar_t>(x OP y.value); \ } \ template \ Simd operator OP(Simd x, Simd y) { \ return asd::convert::scalar_t>(x.value OP y.value); \ } SIMD_DEFAULT_BINARY(+) SIMD_DEFAULT_BINARY(-) SIMD_DEFAULT_BINARY(/) SIMD_DEFAULT_BINARY(*) SIMD_DEFAULT_BINARY(<<) SIMD_DEFAULT_BINARY(>>) SIMD_DEFAULT_BINARY(|) SIMD_DEFAULT_BINARY(^) SIMD_DEFAULT_BINARY(&) SIMD_DEFAULT_BINARY(&&) SIMD_DEFAULT_BINARY(||) #define SIMD_DEFAULT_COMPARISONS(OP) \ template \ Simd operator OP(Simd a, U b) { \ return asd::convert(a.value OP b); \ } \ template \ Simd operator OP(T a, Simd b) { \ return asd::convert(a OP b.value); \ } \ template \ Simd operator OP(Simd a, Simd b) { \ return asd::convert(a.value OP b.value); \ } SIMD_DEFAULT_COMPARISONS(>) SIMD_DEFAULT_COMPARISONS(<) SIMD_DEFAULT_COMPARISONS(>=) SIMD_DEFAULT_COMPARISONS(<=) SIMD_DEFAULT_COMPARISONS(==) SIMD_DEFAULT_COMPARISONS(!=) template Simd clz(Simd x) { auto a = *(uint32x4_t*)(&x); auto b = *((uint32x4_t*)(&x) + 1); a = vclzq_u32(a); b = vclzq_u32(b); return asd::make_uint8(a, b); } template Simd atan2(Simd a, Simd b) { return asd::atan2(a.value, b.value); } template Simd maximum(Simd a, Simd b) { auto out = Simd(asd::max(a.value, b.value)); if constexpr (!std::is_integral_v) { out = select(isnan(b), b, select(isnan(a), a, out)); } return out; } template Simd minimum(Simd a, Simd b) { auto out = Simd(asd::min(a.value, b.value)); if constexpr (!std::is_integral_v) { out = select(isnan(b), b, select(isnan(a), a, out)); } return out; } template Simd remainder(Simd a, Simd b) { Simd r; if constexpr (!std::is_integral_v) { r = asd::remainder(a.value, b.value); } else { r = a - b * (a / b); } if constexpr (std::is_signed_v) { auto mask = r != 0 && (r < 0 != b < 0); r = select(mask, r + b, r); } return r; } template Simd select(Simd mask, Simd x, Simd y) { static_assert(std::is_same_v); if constexpr (sizeof(T1) == 1) { return asd::bitselect(y.value, x.value, asd::convert(mask.value)); } else if constexpr (sizeof(T1) == 2) { return asd::bitselect(y.value, x.value, asd::convert(mask.value)); } else if constexpr (sizeof(T1) == 4) { return asd::bitselect(y.value, x.value, asd::convert(mask.value)); } else { return asd::bitselect(y.value, x.value, asd::convert(mask.value)); } } template Simd pow(Simd base, Simd exp) { if constexpr (!std::is_integral_v) { return asd::pow(base.value, exp.value); } else { Simd res = 1; // Raising an integer to a negative power is undefined if (any(exp < 0)) { return 0; } while (any(exp > 0)) { res = select((exp & 1) != 0, res * base, res); base = select(exp > 0, base * base, base); exp = exp >> 1; } return res; } } template Simd clamp(Simd v, Simd min, Simd max) { return asd::clamp(v.value, min.value, max.value); } template Simd fma(Simd x, Simd y, U z) { return asd::muladd(x.value, y.value, Simd(z).value); } // Reductions template bool all(Simd x) { return asd::all(x.value); } template bool any(Simd x) { return asd::any(x.value); } template T sum(Simd x) { return asd::reduce_add(x.value); } template T max(Simd x) { return asd::reduce_max(x.value); } template T min(Simd x) { return asd::reduce_min(x.value); } template T prod(Simd x) { auto ptr = (T*)&x; auto lhs = load(ptr); auto rhs = load(ptr + N / 2); return prod(lhs * rhs); } } // namespace mlx::core::simd #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #include "mlx/backend/cpu/simd/accelerate_fp16_simd.h" #endif