Log for complex numbers in Metal (#2025)

* Log for complex numbers in Metal

* fix log2
This commit is contained in:
Awni Hannun
2025-03-30 17:04:38 -07:00
committed by GitHub
parent b2d2b37888
commit 28f39e9038
4 changed files with 48 additions and 1 deletions

View File

@@ -845,6 +845,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(result, expected))
a = mx.array(1.0) + 1j * mx.array(2.0)
result = mx.log(a)
expected = np.log(np.array(a))
self.assertTrue(np.allclose(result, expected))
def test_log2(self):
a = mx.array([0.5, 1, 2, 10, 16])
result = mx.log2(a)
@@ -852,6 +857,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(result, expected))
a = mx.array(1.0) + 1j * mx.array(2.0)
result = mx.log2(a)
expected = np.log2(np.array(a))
self.assertTrue(np.allclose(result, expected))
def test_log10(self):
a = mx.array([0.1, 1, 10, 20, 100])
result = mx.log10(a)
@@ -859,6 +869,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(result, expected))
a = mx.array(1.0) + 1j * mx.array(2.0)
result = mx.log10(a)
expected = np.log10(np.array(a))
self.assertTrue(np.allclose(result, expected))
def test_exp(self):
a = mx.array([0, 0.5, -0.5, 5])
result = mx.exp(a)