mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
GPU mx.sign for complex64 (#1326)
This commit is contained in:
parent
63ae767232
commit
99bb7d3a58
@ -373,6 +373,10 @@ struct Sign {
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return x == complex64_t(0) ? x : x / std::abs(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
|
@ -64,6 +64,7 @@ instantiate_unary_all(Cos, complex64, complex64_t)
|
||||
instantiate_unary_all(Cosh, complex64, complex64_t)
|
||||
instantiate_unary_all(Exp, complex64, complex64_t)
|
||||
instantiate_unary_all(Negative, complex64, complex64_t)
|
||||
instantiate_unary_all(Sign, complex64, complex64_t)
|
||||
instantiate_unary_all(Sin, complex64, complex64_t)
|
||||
instantiate_unary_all(Sinh, complex64, complex64_t)
|
||||
instantiate_unary_all(Tan, complex64, complex64_t)
|
||||
|
@ -308,6 +308,14 @@ struct Sign {
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x != 0;
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
if (x == complex64_t(0)) {
|
||||
return x;
|
||||
}
|
||||
return x /
|
||||
(complex64_t)metal::precise::sqrt(x.real * x.real + x.imag * x.imag);
|
||||
};
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
|
@ -760,6 +760,15 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
expected = np.sign(a, dtype=np.float32)
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
a = mx.array([-1.0, 1.0, 0.0, -2.0, 3.0])
|
||||
b = mx.array([-4.0, -3.0, 1.0, 0.0, 3.0])
|
||||
c = a + b * 1j
|
||||
result = mx.sign(c)
|
||||
# np.sign differs in NumPy 1 and 2 so
|
||||
# we manually implement the NumPy 2 version here.
|
||||
expected = c / np.abs(c)
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
def test_logical_not(self):
|
||||
a = mx.array([-1.0, 1.0, 0.0, 1.0, -2.0, 3.0])
|
||||
result = mx.logical_not(a)
|
||||
|
Loading…
Reference in New Issue
Block a user