mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 17:28:10 +08:00
use fp32 for testing, add more complex ops (#2322)
This commit is contained in:
@@ -27,6 +27,8 @@ struct ArcCos {
|
||||
__device__ T operator()(T x) {
|
||||
return acos(x);
|
||||
}
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x);
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
@@ -41,6 +43,8 @@ struct ArcSin {
|
||||
__device__ T operator()(T x) {
|
||||
return asin(x);
|
||||
}
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x);
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
@@ -55,6 +59,8 @@ struct ArcTan {
|
||||
__device__ T operator()(T x) {
|
||||
return atan(x);
|
||||
}
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x);
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
@@ -261,13 +267,6 @@ struct Round {
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return rsqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
@@ -333,6 +332,29 @@ struct Sqrt {
|
||||
__device__ T operator()(T x) {
|
||||
return sqrt(x);
|
||||
}
|
||||
|
||||
__device__ cuComplex operator()(cuComplex x) {
|
||||
auto xr = cuCrealf(x);
|
||||
auto xi = cuCimagf(x);
|
||||
if (xr == 0.0f && xi == 0.0f) {
|
||||
return {0.0f, 0.0f};
|
||||
}
|
||||
auto r = cuCrealf(Abs{}(x));
|
||||
auto a = sqrt((r + xr) / 2.0f);
|
||||
auto b_abs = sqrt((r - xr) / 2.0f);
|
||||
auto b = copysign(b_abs, xi);
|
||||
return {a, b};
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return rsqrt(x);
|
||||
}
|
||||
__device__ cuComplex operator()(cuComplex x) {
|
||||
return 1.0f / Sqrt{}(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
@@ -365,4 +387,22 @@ struct Tanh {
|
||||
}
|
||||
};
|
||||
|
||||
__device__ cuComplex ArcCos::operator()(cuComplex x) {
|
||||
auto i = cuComplex{0.0, 1.0};
|
||||
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
|
||||
return {cuCimagf(y), -cuCrealf(y)};
|
||||
};
|
||||
|
||||
__device__ cuComplex ArcSin::operator()(cuComplex x) {
|
||||
auto i = cuComplex{0.0f, 1.0f};
|
||||
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
|
||||
return {cuCimagf(y), -cuCrealf(y)};
|
||||
};
|
||||
|
||||
__device__ cuComplex ArcTan::operator()(cuComplex x) {
|
||||
auto i = cuComplex{0.0f, 1.0f};
|
||||
auto ix = i * x;
|
||||
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
@@ -342,8 +342,6 @@ void LayerNormVJP::eval_gpu(
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
gb.set_data(allocator::malloc(gb.nbytes()));
|
||||
|
||||
// Finish with the gradient for b in case we had a b.
|
||||
if (gb.ndim() == 1 && gb.size() == axis_size) {
|
||||
|
@@ -304,7 +304,6 @@ void RMSNormVJP::eval_gpu(
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
gw.set_data(allocator::malloc(gw.nbytes()));
|
||||
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
|
@@ -20,38 +20,35 @@ namespace cu {
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_unary_op() {
|
||||
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
|
||||
std::is_same_v<Op, Sign>) {
|
||||
std::is_same_v<Op, Sign> || std::is_same_v<Op, Square>) {
|
||||
return std::is_same_v<In, Out>;
|
||||
}
|
||||
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcCosh> ||
|
||||
std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
|
||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
||||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
|
||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
||||
if (std::is_same_v<Op, ArcCosh> || std::is_same_v<Op, ArcSinh> ||
|
||||
std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, Erf> ||
|
||||
std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Expm1> ||
|
||||
std::is_same_v<Op, Sigmoid>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseInvert>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||
!std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor> ||
|
||||
std::is_same_v<Op, Square>) {
|
||||
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
|
||||
return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
|
||||
}
|
||||
if (std::is_same_v<Op, Conjugate>) {
|
||||
return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
|
||||
}
|
||||
if (std::is_same_v<Op, Cos> || std::is_same_v<Op, Cosh> ||
|
||||
std::is_same_v<Op, Exp> || std::is_same_v<Op, Round> ||
|
||||
std::is_same_v<Op, Sin> || std::is_same_v<Op, Sinh> ||
|
||||
std::is_same_v<Op, Tan> || std::is_same_v<Op, Tanh>) {
|
||||
return std::is_same_v<In, Out> &&
|
||||
(is_floating_v<In> || std::is_same_v<In, complex64_t>);
|
||||
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
|
||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
|
||||
std::is_same_v<Op, Cosh> || std::is_same_v<Op, Exp> ||
|
||||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p> ||
|
||||
std::is_same_v<Op, Round> || std::is_same_v<Op, Rsqrt> ||
|
||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sin> ||
|
||||
std::is_same_v<Op, Sinh> || std::is_same_v<Op, Tan> ||
|
||||
std::is_same_v<Op, Tanh>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
|
||||
return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
|
||||
|
Reference in New Issue
Block a user