GPU mx.sign for complex64 (#1326)

This commit is contained in:
Alex Barron
2024-08-14 07:54:53 -07:00
committed by GitHub
parent 63ae767232
commit 99bb7d3a58
4 changed files with 22 additions and 0 deletions

View File

@@ -760,6 +760,15 @@ class TestOps(mlx_tests.MLXTestCase):
expected = np.sign(a, dtype=np.float32)
self.assertTrue(np.allclose(result, expected))
a = mx.array([-1.0, 1.0, 0.0, -2.0, 3.0])
b = mx.array([-4.0, -3.0, 1.0, 0.0, 3.0])
c = a + b * 1j
result = mx.sign(c)
# np.sign differs in NumPy 1 and 2 so
# we manually implement the NumPy 2 version here.
expected = c / np.abs(c)
self.assertTrue(np.allclose(result, expected))
def test_logical_not(self):
a = mx.array([-1.0, 1.0, 0.0, 1.0, -2.0, 3.0])
result = mx.logical_not(a)