mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
GPU mx.sign for complex64 (#1326)
This commit is contained in:
@@ -373,6 +373,10 @@ struct Sign {
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return x == complex64_t(0) ? x : x / std::abs(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
|
@@ -64,6 +64,7 @@ instantiate_unary_all(Cos, complex64, complex64_t)
|
||||
instantiate_unary_all(Cosh, complex64, complex64_t)
|
||||
instantiate_unary_all(Exp, complex64, complex64_t)
|
||||
instantiate_unary_all(Negative, complex64, complex64_t)
|
||||
instantiate_unary_all(Sign, complex64, complex64_t)
|
||||
instantiate_unary_all(Sin, complex64, complex64_t)
|
||||
instantiate_unary_all(Sinh, complex64, complex64_t)
|
||||
instantiate_unary_all(Tan, complex64, complex64_t)
|
||||
|
@@ -308,6 +308,14 @@ struct Sign {
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x != 0;
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
if (x == complex64_t(0)) {
|
||||
return x;
|
||||
}
|
||||
return x /
|
||||
(complex64_t)metal::precise::sqrt(x.real * x.real + x.imag * x.imag);
|
||||
};
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
|
Reference in New Issue
Block a user