mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 10:51:21 +08:00
GPU mx.sign for complex64 (#1326)
This commit is contained in:
parent
63ae767232
commit
99bb7d3a58
@ -373,6 +373,10 @@ struct Sign {
|
|||||||
uint64_t operator()(uint64_t x) {
|
uint64_t operator()(uint64_t x) {
|
||||||
return x != 0;
|
return x != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t x) {
|
||||||
|
return x == complex64_t(0) ? x : x / std::abs(x);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Sin {
|
struct Sin {
|
||||||
|
@ -64,6 +64,7 @@ instantiate_unary_all(Cos, complex64, complex64_t)
|
|||||||
instantiate_unary_all(Cosh, complex64, complex64_t)
|
instantiate_unary_all(Cosh, complex64, complex64_t)
|
||||||
instantiate_unary_all(Exp, complex64, complex64_t)
|
instantiate_unary_all(Exp, complex64, complex64_t)
|
||||||
instantiate_unary_all(Negative, complex64, complex64_t)
|
instantiate_unary_all(Negative, complex64, complex64_t)
|
||||||
|
instantiate_unary_all(Sign, complex64, complex64_t)
|
||||||
instantiate_unary_all(Sin, complex64, complex64_t)
|
instantiate_unary_all(Sin, complex64, complex64_t)
|
||||||
instantiate_unary_all(Sinh, complex64, complex64_t)
|
instantiate_unary_all(Sinh, complex64, complex64_t)
|
||||||
instantiate_unary_all(Tan, complex64, complex64_t)
|
instantiate_unary_all(Tan, complex64, complex64_t)
|
||||||
|
@ -308,6 +308,14 @@ struct Sign {
|
|||||||
uint32_t operator()(uint32_t x) {
|
uint32_t operator()(uint32_t x) {
|
||||||
return x != 0;
|
return x != 0;
|
||||||
};
|
};
|
||||||
|
template <>
|
||||||
|
complex64_t operator()(complex64_t x) {
|
||||||
|
if (x == complex64_t(0)) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
return x /
|
||||||
|
(complex64_t)metal::precise::sqrt(x.real * x.real + x.imag * x.imag);
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Sin {
|
struct Sin {
|
||||||
|
@ -760,6 +760,15 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
expected = np.sign(a, dtype=np.float32)
|
expected = np.sign(a, dtype=np.float32)
|
||||||
self.assertTrue(np.allclose(result, expected))
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
|
a = mx.array([-1.0, 1.0, 0.0, -2.0, 3.0])
|
||||||
|
b = mx.array([-4.0, -3.0, 1.0, 0.0, 3.0])
|
||||||
|
c = a + b * 1j
|
||||||
|
result = mx.sign(c)
|
||||||
|
# np.sign differs in NumPy 1 and 2 so
|
||||||
|
# we manually implement the NumPy 2 version here.
|
||||||
|
expected = c / np.abs(c)
|
||||||
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
def test_logical_not(self):
|
def test_logical_not(self):
|
||||||
a = mx.array([-1.0, 1.0, 0.0, 1.0, -2.0, 3.0])
|
a = mx.array([-1.0, 1.0, 0.0, 1.0, -2.0, 3.0])
|
||||||
result = mx.logical_not(a)
|
result = mx.logical_not(a)
|
||||||
|
Loading…
Reference in New Issue
Block a user