mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	GPU mx.sign for complex64 (#1326)
This commit is contained in:
		| @@ -373,6 +373,10 @@ struct Sign { | ||||
|   uint64_t operator()(uint64_t x) { | ||||
|     return x != 0; | ||||
|   } | ||||
|  | ||||
|   complex64_t operator()(complex64_t x) { | ||||
|     return x == complex64_t(0) ? x : x / std::abs(x); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| struct Sin { | ||||
|   | ||||
| @@ -64,6 +64,7 @@ instantiate_unary_all(Cos, complex64, complex64_t) | ||||
| instantiate_unary_all(Cosh, complex64, complex64_t) | ||||
| instantiate_unary_all(Exp, complex64, complex64_t) | ||||
| instantiate_unary_all(Negative, complex64, complex64_t) | ||||
| instantiate_unary_all(Sign, complex64, complex64_t) | ||||
| instantiate_unary_all(Sin, complex64, complex64_t) | ||||
| instantiate_unary_all(Sinh, complex64, complex64_t) | ||||
| instantiate_unary_all(Tan, complex64, complex64_t) | ||||
|   | ||||
| @@ -308,6 +308,14 @@ struct Sign { | ||||
|   uint32_t operator()(uint32_t x) { | ||||
|     return x != 0; | ||||
|   }; | ||||
|   template <> | ||||
|   complex64_t operator()(complex64_t x) { | ||||
|     if (x == complex64_t(0)) { | ||||
|       return x; | ||||
|     } | ||||
|     return x / | ||||
|         (complex64_t)metal::precise::sqrt(x.real * x.real + x.imag * x.imag); | ||||
|   }; | ||||
| }; | ||||
|  | ||||
| struct Sin { | ||||
|   | ||||
| @@ -760,6 +760,15 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         expected = np.sign(a, dtype=np.float32) | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|         a = mx.array([-1.0, 1.0, 0.0, -2.0, 3.0]) | ||||
|         b = mx.array([-4.0, -3.0, 1.0, 0.0, 3.0]) | ||||
|         c = a + b * 1j | ||||
|         result = mx.sign(c) | ||||
|         # np.sign differs in NumPy 1 and 2 so | ||||
|         # we manually implement the NumPy 2 version here. | ||||
|         expected = c / np.abs(c) | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|     def test_logical_not(self): | ||||
|         a = mx.array([-1.0, 1.0, 0.0, 1.0, -2.0, 3.0]) | ||||
|         result = mx.logical_not(a) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Alex Barron
					Alex Barron