mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
use fp32 for testing, add more complex ops (#2322)
This commit is contained in:
@@ -27,6 +27,8 @@ struct ArcCos {
|
||||
__device__ T operator()(T x) {
|
||||
return acos(x);
|
||||
}
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x);
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
@@ -41,6 +43,8 @@ struct ArcSin {
|
||||
__device__ T operator()(T x) {
|
||||
return asin(x);
|
||||
}
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x);
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
@@ -55,6 +59,8 @@ struct ArcTan {
|
||||
__device__ T operator()(T x) {
|
||||
return atan(x);
|
||||
}
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x);
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
@@ -261,13 +267,6 @@ struct Round {
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return rsqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
@@ -333,6 +332,29 @@ struct Sqrt {
|
||||
__device__ T operator()(T x) {
|
||||
return sqrt(x);
|
||||
}
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x) {
|
||||
auto xr = cuCrealf(x);
|
||||
auto xi = cuCimagf(x);
|
||||
if (xr == 0.0f && xi == 0.0f) {
|
||||
return {0.0f, 0.0f};
|
||||
}
|
||||
auto r = cuCrealf(Abs{}(x));
|
||||
auto a = sqrt((r + xr) / 2.0f);
|
||||
auto b_abs = sqrt((r - xr) / 2.0f);
|
||||
auto b = copysign(b_abs, xi);
|
||||
return {a, b};
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return rsqrt(x);
|
||||
}
|
||||
__device__ cuComplex operator()(cuComplex x) {
|
||||
return 1.0f / Sqrt{}(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
@@ -365,4 +387,22 @@ struct Tanh {
|
||||
}
|
||||
};
|
||||
|
||||
__device__ cuComplex ArcCos::operator()(cuComplex x) {
|
||||
auto i = cuComplex{0.0, 1.0};
|
||||
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
|
||||
return {cuCimagf(y), -cuCrealf(y)};
|
||||
};
|
||||
|
||||
__device__ cuComplex ArcSin::operator()(cuComplex x) {
|
||||
auto i = cuComplex{0.0f, 1.0f};
|
||||
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
|
||||
return {cuCimagf(y), -cuCrealf(y)};
|
||||
};
|
||||
|
||||
__device__ cuComplex ArcTan::operator()(cuComplex x) {
|
||||
auto i = cuComplex{0.0f, 1.0f};
|
||||
auto ix = i * x;
|
||||
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
||||
Reference in New Issue
Block a user