use fp32 for testing, add more complex ops (#2322)

This commit is contained in:
Awni Hannun
2025-07-01 07:30:00 -07:00
committed by GitHub
parent 3d5e17e507
commit dd4f53db63
6 changed files with 68 additions and 40 deletions

View File

@@ -27,6 +27,8 @@ struct ArcCos {
__device__ T operator()(T x) {
return acos(x);
}
__device__ cuComplex operator()(cuComplex x);
};
struct ArcCosh {
@@ -41,6 +43,8 @@ struct ArcSin {
__device__ T operator()(T x) {
return asin(x);
}
__device__ cuComplex operator()(cuComplex x);
};
struct ArcSinh {
@@ -55,6 +59,8 @@ struct ArcTan {
__device__ T operator()(T x) {
return atan(x);
}
__device__ cuComplex operator()(cuComplex x);
};
struct ArcTanh {
@@ -261,13 +267,6 @@ struct Round {
}
};
struct Rsqrt {
template <typename T>
__device__ T operator()(T x) {
return rsqrt(x);
}
};
struct Sigmoid {
template <typename T>
__device__ T operator()(T x) {
@@ -333,6 +332,29 @@ struct Sqrt {
__device__ T operator()(T x) {
return sqrt(x);
}
__device__ cuComplex operator()(cuComplex x) {
auto xr = cuCrealf(x);
auto xi = cuCimagf(x);
if (xr == 0.0f && xi == 0.0f) {
return {0.0f, 0.0f};
}
auto r = cuCrealf(Abs{}(x));
auto a = sqrt((r + xr) / 2.0f);
auto b_abs = sqrt((r - xr) / 2.0f);
auto b = copysign(b_abs, xi);
return {a, b};
}
};
struct Rsqrt {
template <typename T>
__device__ T operator()(T x) {
return rsqrt(x);
}
__device__ cuComplex operator()(cuComplex x) {
return 1.0f / Sqrt{}(x);
}
};
struct Tan {
@@ -365,4 +387,22 @@ struct Tanh {
}
};
__device__ cuComplex ArcCos::operator()(cuComplex x) {
auto i = cuComplex{0.0, 1.0};
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
return {cuCimagf(y), -cuCrealf(y)};
};
__device__ cuComplex ArcSin::operator()(cuComplex x) {
auto i = cuComplex{0.0f, 1.0f};
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
return {cuCimagf(y), -cuCrealf(y)};
};
__device__ cuComplex ArcTan::operator()(cuComplex x) {
auto i = cuComplex{0.0f, 1.0f};
auto ix = i * x;
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
};
} // namespace mlx::core::cu