#pragma once

#include <algorithm>
#include <cmath>
#include <complex>
#include <functional>
#include <type_traits>
#include <utility>

// Scalar (single-lane) implementation of the Simd abstraction. Every
// operation in this file is defined for Simd<T, 1> and forwards to the
// equivalent scalar / std operation on the wrapped `value`.
namespace mlx::core::simd {

// Primary template: N lanes of type T. Only the N == 1 specialization is
// defined in this file.
template <typename T, int N>
struct Simd;

// Lane count to use for T. This file only provides the scalar default of 1.
template <typename T>
static constexpr int max_size = 1;

// Scalar specialization: a thin wrapper around one value of type T.
template <typename T>
struct Simd<T, 1> {
  static constexpr int size = 1;
  T value;
  Simd() {}
  // The converting constructors are deliberately implicit so that plain
  // scalar results (e.g. `return a.value + b.value;`) convert back to Simd.
  template <typename U>
  Simd(Simd<U, 1> v) : value(v.value) {}
  template <typename U>
  Simd(U v) : value(v) {}
};

// Reads a Simd<T, N> from the memory starting at x.
template <typename T, int N>
Simd<T, N> load(const T* x) {
  return *reinterpret_cast<const Simd<T, N>*>(x);
}

// Writes x to the memory starting at dst.
template <typename T, int N>
void store(T* dst, Simd<T, N> x) {
  // Maintain invariant that bool is either 0 or 1 as
  // simd comparison ops set all bits in the result to 1
  if constexpr (std::is_same_v<T, bool> && N > 1) {
    x = x & 1;
  }
  *reinterpret_cast<Simd<T, N>*>(dst) = x;
}

// True for types exposing a `.real()` member (e.g. std::complex).
template <typename, typename = void>
constexpr bool is_complex = false;

template <typename T>
constexpr bool is_complex<T, std::void_t<decltype(std::declval<T>().real())>> =
    true;

// Rounds to the nearest integer in the current rounding mode; for complex
// types the real and imaginary parts are rounded independently.
template <typename T>
Simd<T, 1> rint(Simd<T, 1> in) {
  if constexpr (is_complex<T>) {
    return Simd<T, 1>{
        T{std::rint(in.value.real()), std::rint(in.value.imag())}};
  } else {
    return Simd<T, 1>{std::rint(in.value)};
  }
}

// 1 / sqrt(in).
template <typename T>
Simd<T, 1> rsqrt(Simd<T, 1> in) {
  return T(1.0) / sqrt(in);
}

// 1 / in.
template <typename T>
Simd<T, 1> recip(Simd<T, 1> in) {
  return T(1.0) / in;
}

// Defines a unary function `name` that applies `op` to the wrapped value.
#define DEFAULT_UNARY(name, op)    \
  template <typename T>            \
  Simd<T, 1> name(Simd<T, 1> in) { \
    return op(in.value);           \
  }

DEFAULT_UNARY(operator-, std::negate{})
DEFAULT_UNARY(operator!, std::logical_not{})
DEFAULT_UNARY(abs, std::abs)
DEFAULT_UNARY(acos, std::acos)
DEFAULT_UNARY(acosh, std::acosh)
DEFAULT_UNARY(asin, std::asin)
DEFAULT_UNARY(asinh, std::asinh)
DEFAULT_UNARY(atan, std::atan)
DEFAULT_UNARY(atanh, std::atanh)
DEFAULT_UNARY(ceil, std::ceil)
DEFAULT_UNARY(conj, std::conj)
DEFAULT_UNARY(cosh, std::cosh)
DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log2, std::log2)
DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh)
DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh)

// Real / imaginary part; the lane type is whatever std::real/std::imag
// return for T.
template <typename T>
auto real(Simd<T, 1> in) -> Simd<decltype(std::real(in.value)), 1> {
  return std::real(in.value);
}
template <typename T>
auto imag(Simd<T, 1> in) -> Simd<decltype(std::imag(in.value)), 1> {
  return std::imag(in.value);
}
template <typename T>
Simd<bool, 1> isnan(Simd<T, 1> in) {
  return std::isnan(in.value);
}

// Defines operator OP for (Simd, Simd), (scalar, Simd) and (Simd, scalar).
// The result lane type is the type of the corresponding scalar expression.
#define DEFAULT_BINARY(OP)                                                 \
  template <typename T1, typename T2>                                      \
  auto operator OP(Simd<T1, 1> a, Simd<T2, 1> b)                           \
      ->Simd<decltype(a.value OP b.value), 1> {                            \
    return a.value OP b.value;                                             \
  }                                                                        \
  template <typename T1, typename T2>                                      \
  auto operator OP(T1 a, Simd<T2, 1> b)->Simd<decltype(a OP b.value), 1> { \
    return a OP b.value;                                                   \
  }                                                                        \
  template <typename T1, typename T2>                                      \
  auto operator OP(Simd<T1, 1> a, T2 b)->Simd<decltype(a.value OP b), 1> { \
    return a.value OP b;                                                   \
  }

DEFAULT_BINARY(+)
DEFAULT_BINARY(-)
DEFAULT_BINARY(*)
DEFAULT_BINARY(/)
DEFAULT_BINARY(<<)
DEFAULT_BINARY(>>)
DEFAULT_BINARY(|)
DEFAULT_BINARY(^)
DEFAULT_BINARY(&)
DEFAULT_BINARY(&&)
DEFAULT_BINARY(||)

// Remainder whose nonzero result takes the sign of the divisor b: the
// truncated (`%`) or IEEE (`std::remainder`) remainder is shifted by b
// whenever its sign differs from b's.
template <typename T>
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
  T a = a_.value;
  T b = b_.value;
  T r;
  if constexpr (std::is_integral_v<T>) {
    r = a % b;
  } else {
    r = std::remainder(a, b);
  }
  if constexpr (std::is_signed_v<T>) {
    if (r != 0 && ((r < 0) != (b < 0))) {
      r += b;
    }
  }
  return r;
}

// NaN-propagating maximum: a NaN in a is returned explicitly, and a NaN in b
// makes `a > b` false so b (NaN) is returned.
template <typename T>
Simd<T, 1> maximum(Simd<T, 1> a_, Simd<T, 1> b_) {
  T a = a_.value;
  T b = b_.value;
  if constexpr (!std::is_integral_v<T>) {
    if (std::isnan(a)) {
      return a;
    }
  }
  return (a > b) ? a : b;
}

// NaN-propagating minimum (see maximum).
template <typename T>
Simd<T, 1> minimum(Simd<T, 1> a_, Simd<T, 1> b_) {
  T a = a_.value;
  T b = b_.value;
  if constexpr (!std::is_integral_v<T>) {
    if (std::isnan(a)) {
      return a;
    }
  }
  return (a < b) ? a : b;
}

// a raised to the power b. Non-integral types use std::pow; integral types
// use exponentiation by squaring. A negative integral exponent returns 0.
template <typename T>
Simd<T, 1> pow(Simd<T, 1> a, Simd<T, 1> b) {
  T base = a.value;
  T exp = b.value;
  if constexpr (!std::is_integral_v<T>) {
    return std::pow(base, exp);
  } else {
    if constexpr (std::is_signed_v<T>) {
      // An arithmetic right shift keeps a negative exponent negative, so the
      // loop below would never terminate. Integer powers with a negative
      // exponent are not defined here; return 0.
      if (exp < 0) {
        return T(0);
      }
    }
    T res = 1;
    while (exp) {
      if (exp & 1) {
        res *= base;
      }
      exp >>= 1;
      // Square only while bits remain: the final squaring is never used and
      // can overflow (UB for signed types) even when the result fits.
      if (exp) {
        base *= base;
      }
    }
    return res;
  }
}

template <typename T>
Simd<T, 1> atan2(Simd<T, 1> a, Simd<T, 1> b) {
  return std::atan2(a.value, b.value);
}

// Defines comparison OP for (Simd, Simd), (scalar, Simd) and (Simd, scalar),
// producing a Simd<bool, 1>.
#define DEFAULT_COMPARISONS(OP)                             \
  template <typename T1, typename T2>                       \
  Simd<bool, 1> operator OP(Simd<T1, 1> a, Simd<T2, 1> b) { \
    return a.value OP b.value;                              \
  }                                                         \
  template <typename T1, typename T2>                       \
  Simd<bool, 1> operator OP(T1 a, Simd<T2, 1> b) {          \
    return a OP b.value;                                    \
  }                                                         \
  template <typename T1, typename T2>                       \
  Simd<bool, 1> operator OP(Simd<T1, 1> a, T2 b) {          \
    return a.value OP b;                                    \
  }

DEFAULT_COMPARISONS(>)
DEFAULT_COMPARISONS(<)
DEFAULT_COMPARISONS(>=)
DEFAULT_COMPARISONS(<=)
DEFAULT_COMPARISONS(==)
DEFAULT_COMPARISONS(!=)

// Returns x where mask is truthy, otherwise y.
template <typename MaskT, typename T>
Simd<T, 1> select(Simd<MaskT, 1> mask, Simd<T, 1> x, Simd<T, 1> y) {
  return mask.value ? x.value : y.value;
}

template <typename T>
Simd<T, 1> clamp(Simd<T, 1> v, Simd<T, 1> min, Simd<T, 1> max) {
  return std::clamp(v.value, min.value, max.value);
}

// x * y + z with a single rounding; z may be a scalar or a Simd and is
// converted to Simd<T, 1> first.
template <typename T, typename U>
Simd<T, 1> fma(Simd<T, 1> x, Simd<T, 1> y, U z) {
  return std::fma(x.value, y.value, Simd<T, 1>(z).value);
}

// Reductions: with a single lane, every reduction is the lane value itself.
#define DEFAULT_REDUCTION(name, type) \
  template <typename T>               \
  type name(Simd<T, 1> x) {           \
    return x.value;                   \
  }

DEFAULT_REDUCTION(max, T)
DEFAULT_REDUCTION(min, T)
DEFAULT_REDUCTION(sum, T)
DEFAULT_REDUCTION(any, bool)
DEFAULT_REDUCTION(all, bool)

} // namespace mlx::core::simd