GPU mx.sign for complex64 (#1326)

This commit is contained in:
Alex Barron 2024-08-14 07:54:53 -07:00 committed by GitHub
parent 63ae767232
commit 99bb7d3a58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 22 additions and 0 deletions

View File

@ -373,6 +373,10 @@ struct Sign {
uint64_t operator()(uint64_t x) {
return x != 0;
}
complex64_t operator()(complex64_t x) {
return x == complex64_t(0) ? x : x / std::abs(x);
}
};
struct Sin {

View File

@ -64,6 +64,7 @@ instantiate_unary_all(Cos, complex64, complex64_t)
instantiate_unary_all(Cosh, complex64, complex64_t)
instantiate_unary_all(Exp, complex64, complex64_t)
instantiate_unary_all(Negative, complex64, complex64_t)
instantiate_unary_all(Sign, complex64, complex64_t)
instantiate_unary_all(Sin, complex64, complex64_t)
instantiate_unary_all(Sinh, complex64, complex64_t)
instantiate_unary_all(Tan, complex64, complex64_t)

View File

@ -308,6 +308,14 @@ struct Sign {
uint32_t operator()(uint32_t x) {
return x != 0;
};
template <>
complex64_t operator()(complex64_t x) {
if (x == complex64_t(0)) {
return x;
}
return x /
(complex64_t)metal::precise::sqrt(x.real * x.real + x.imag * x.imag);
};
};
struct Sin {

View File

@ -760,6 +760,15 @@ class TestOps(mlx_tests.MLXTestCase):
expected = np.sign(a, dtype=np.float32)
self.assertTrue(np.allclose(result, expected))
a = mx.array([-1.0, 1.0, 0.0, -2.0, 3.0])
b = mx.array([-4.0, -3.0, 1.0, 0.0, 3.0])
c = a + b * 1j
result = mx.sign(c)
# np.sign differs in NumPy 1 and 2 so
# we manually implement the NumPy 2 version here.
expected = c / np.abs(c)
self.assertTrue(np.allclose(result, expected))
def test_logical_not(self):
a = mx.array([-1.0, 1.0, 0.0, 1.0, -2.0, 3.0])
result = mx.logical_not(a)