// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include "mlx/backend/cuda/kernels/utils.cuh"

namespace mlx::core::cu {

// Element-wise unary functors for device code. Each struct exposes a
// `__device__ operator()` taking one value and returning the result.
// Branches selected with `if constexpr` specialize the behavior per type;
// the complex branches operate on `cuComplex` through cuCrealf/cuCimagf.

// Absolute value. Unsigned inputs are returned as-is; for complex inputs the
// magnitude is returned in the real part with a zero imaginary part.
struct Abs {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_unsigned_v<T>) {
      return x;
    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {sqrt(cuCrealf(x) * cuCrealf(x) + cuCimagf(x) * cuCimagf(x)), 0};
    } else {
      return abs(x);
    }
  }
};

// Inverse cosine.
struct ArcCos {
  template <typename T>
  __device__ T operator()(T x) {
    return acos(x);
  }
};

// Inverse hyperbolic cosine.
struct ArcCosh {
  template <typename T>
  __device__ T operator()(T x) {
    return acosh(x);
  }
};

// Inverse sine.
struct ArcSin {
  template <typename T>
  __device__ T operator()(T x) {
    return asin(x);
  }
};

// Inverse hyperbolic sine.
struct ArcSinh {
  template <typename T>
  __device__ T operator()(T x) {
    return asinh(x);
  }
};

// Inverse tangent.
struct ArcTan {
  template <typename T>
  __device__ T operator()(T x) {
    return atan(x);
  }
};

// Inverse hyperbolic tangent.
struct ArcTanh {
  template <typename T>
  __device__ T operator()(T x) {
    return atanh(x);
  }
};

// Bitwise complement.
struct BitwiseInvert {
  template <typename T>
  __device__ T operator()(T x) {
    return ~x;
  }
};

// Ceiling. Integral inputs are returned unchanged.
struct Ceil {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_integral_v<T>) {
      return x;
    } else {
      return ceil(x);
    }
  }
};

// Complex conjugate: negates the imaginary part.
struct Conjugate {
  __device__ cuComplex operator()(cuComplex x) {
    return {cuCrealf(x), -cuCimagf(x)};
  }
};

// Cosine. Complex: cos(a + bi) = cos(a)cosh(b) - i sin(a)sinh(b).
struct Cos {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {
          cos(cuCrealf(x)) * cosh(cuCimagf(x)),
          -sin(cuCrealf(x)) * sinh(cuCimagf(x))};
    } else {
      return cos(x);
    }
  }
};

// Hyperbolic cosine. Complex: cosh(a + bi) = cosh(a)cos(b) + i sinh(a)sin(b).
struct Cosh {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {
          cosh(cuCrealf(x)) * cos(cuCimagf(x)),
          sinh(cuCrealf(x)) * sin(cuCimagf(x))};
    } else {
      return cosh(x);
    }
  }
};

// Error function. Half and bfloat16 inputs are converted to float first.
struct Erf {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, __half>) {
      return erf(__half2float(x));
    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
      return erf(__bfloat162float(x));
    } else {
      return erf(x);
    }
  }
};

// Inverse error function. Half and bfloat16 inputs are converted to float
// first.
struct ErfInv {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, __half>) {
      return erfinv(__half2float(x));
    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
      return erfinv(__bfloat162float(x));
    } else {
      return erfinv(x);
    }
  }
};

// Exponential. Complex: exp(a + bi) = e^a * (cos(b) + i sin(b)).
struct Exp {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      auto m = exp(cuCrealf(x));
      // The imaginary part uses sin (Euler's formula), not sinh.
      return {m * cos(cuCimagf(x)), m * sin(cuCimagf(x))};
    } else {
      return exp(x);
    }
  }
};

// exp(x) - 1. Half and bfloat16 inputs are converted to float first.
struct Expm1 {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, __half>) {
      return expm1(__half2float(x));
    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
      return expm1(__bfloat162float(x));
    } else {
      return expm1(x);
    }
  }
};

// Floor. Integral inputs are returned unchanged.
struct Floor {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_integral_v<T>) {
      return x;
    } else {
      return floor(x);
    }
  }
};

// Imaginary part of a complex value.
struct Imag {
  __device__ float operator()(cuComplex x) {
    return cuCimagf(x);
  }
};

// Natural logarithm.
struct Log {
  template <typename T>
  __device__ T operator()(T x) {
    return log(x);
  }
};

// Base-2 logarithm.
struct Log2 {
  template <typename T>
  __device__ T operator()(T x) {
    return log2(x);
  }
};

// Base-10 logarithm.
struct Log10 {
  template <typename T>
  __device__ T operator()(T x) {
    return log10(x);
  }
};

// log(1 + x).
struct Log1p {
  template <typename T>
  __device__ T operator()(T x) {
    return log1p(x);
  }
};

// Boolean negation.
struct LogicalNot {
  __device__ bool operator()(bool x) {
    return !x;
  }
};

// Arithmetic negation. The complex branch computes `0 - x` instead of `-x`
// (presumably because only binary minus is provided for cuComplex — confirm
// against the operator overloads in the included headers).
struct Negative {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return 0 - x;
    } else {
      return -x;
    }
  }
};

// Real part of a complex value.
struct Real {
  __device__ float operator()(cuComplex x) {
    return cuCrealf(x);
  }
};

// Rounds to an integral value with rint; complex inputs are rounded
// component-wise.
struct Round {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {rint(cuCrealf(x)), rint(cuCimagf(x))};
    } else {
      return rint(x);
    }
  }
};

// Reciprocal square root.
struct Rsqrt {
  template <typename T>
  __device__ T operator()(T x) {
    return rsqrt(x);
  }
};

// Logistic sigmoid. Evaluates 1 / (1 + exp(-|x|)) so the argument passed to
// exp is never positive, then mirrors the result for negative inputs using
// sigmoid(-x) = 1 - sigmoid(x).
struct Sigmoid {
  template <typename T>
  __device__ T operator()(T x) {
    T y = 1 / (1 + exp(-abs(x)));
    return (x < 0) ? 1 - y : y;
  }
};

// Sign of the input.
//  - unsigned: 1 for non-zero values, 0 otherwise.
//  - complex:  x / |x|, or x itself when both components are zero.
//  - other:    (x > 0) - (x < 0).
struct Sign {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_unsigned_v<T>) {
      return x != 0;
    } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      if (cuCrealf(x) == 0 && cuCimagf(x) == 0) {
        return x;
      } else {
        // Relies on a cuComplex division operator from the included headers.
        return x / Abs()(x);
      }
    } else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
      // NOTE(review): the type arguments of this branch (__nv_bfloat16 and
      // the float cast) should be confirmed against the upstream file.
      return static_cast<float>((x > T(0.f)) - (x < T(0.f)));
    } else {
      return (x > T(0)) - (x < T(0));
    }
  }
};

// Sine. Complex: sin(a + bi) = sin(a)cosh(b) + i cos(a)sinh(b).
struct Sin {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {
          sin(cuCrealf(x)) * cosh(cuCimagf(x)),
          cos(cuCrealf(x)) * sinh(cuCimagf(x))};
    } else {
      return sin(x);
    }
  }
};

// Hyperbolic sine. Complex: sinh(a + bi) = sinh(a)cos(b) + i cosh(a)sin(b).
struct Sinh {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return {
          sinh(cuCrealf(x)) * cos(cuCimagf(x)),
          cosh(cuCrealf(x)) * sin(cuCimagf(x))};
    } else {
      return sinh(x);
    }
  }
};

// x * x.
struct Square {
  template <typename T>
  __device__ T operator()(T x) {
    return x * x;
  }
};

// Square root.
struct Sqrt {
  template <typename T>
  __device__ T operator()(T x) {
    return sqrt(x);
  }
};

// Tangent. Complex:
//   tan(a + bi) = (tan(a) + i tanh(b)) / (1 - i tan(a)tanh(b)),
// evaluated by multiplying through by the conjugate of the denominator.
struct Tan {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      float tan_a = tan(cuCrealf(x));
      float tanh_b = tanh(cuCimagf(x));
      float t1 = tan_a * tanh_b;
      // Float literal keeps the arithmetic in single precision.
      float denom = 1.0f + t1 * t1;
      return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
    } else {
      return tan(x);
    }
  }
};

// Hyperbolic tangent. Complex:
//   tanh(a + bi) = (tanh(a) + i tan(b)) / (1 + i tanh(a)tan(b)),
// evaluated by multiplying through by the conjugate of the denominator.
struct Tanh {
  template <typename T>
  __device__ T operator()(T x) {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      float tanh_a = tanh(cuCrealf(x));
      float tan_b = tan(cuCimagf(x));
      float t1 = tanh_a * tan_b;
      // Float literal keeps the arithmetic in single precision.
      float denom = 1.0f + t1 * t1;
      return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
    } else {
      return tanh(x);
    }
  }
};

} // namespace mlx::core::cu