mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Log for complex numbers in Metal (#2025)
* Log for complex numbers in Metal * fix log2
This commit is contained in:
		| @@ -845,6 +845,11 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|         a = mx.array(1.0) + 1j * mx.array(2.0) | ||||
|         result = mx.log(a) | ||||
|         expected = np.log(np.array(a)) | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|     def test_log2(self): | ||||
|         a = mx.array([0.5, 1, 2, 10, 16]) | ||||
|         result = mx.log2(a) | ||||
| @@ -852,6 +857,11 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|         a = mx.array(1.0) + 1j * mx.array(2.0) | ||||
|         result = mx.log2(a) | ||||
|         expected = np.log2(np.array(a)) | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|     def test_log10(self): | ||||
|         a = mx.array([0.1, 1, 10, 20, 100]) | ||||
|         result = mx.log10(a) | ||||
| @@ -859,6 +869,11 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|         a = mx.array(1.0) + 1j * mx.array(2.0) | ||||
|         result = mx.log10(a) | ||||
|         expected = np.log10(np.array(a)) | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|     def test_exp(self): | ||||
|         a = mx.array([0, 0.5, -0.5, 5]) | ||||
|         result = mx.exp(a) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun