mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Add conjugate operator (#1100)
* cpu and gpu impl * add mx.conj and array.conj() --------- Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
		| @@ -1584,5 +1584,13 @@ void init_array(nb::module_& m) { | ||||
|           "stream"_a = nb::none(), | ||||
|           R"pbdoc( | ||||
|             Extract a diagonal or construct a diagonal matrix. | ||||
|         )pbdoc"); | ||||
|         )pbdoc") | ||||
|       .def( | ||||
|           "conj", | ||||
|           [](const array& a, StreamOrDevice s) { | ||||
|             return mlx::core::conjugate(to_array(a), s); | ||||
|           }, | ||||
|           nb::kw_only(), | ||||
|           "stream"_a = nb::none(), | ||||
|           "See :func:`conj`."); | ||||
| } | ||||
|   | ||||
| @@ -3027,6 +3027,40 @@ void init_ops(nb::module_& m) { | ||||
|           inclusive (bool): The i-th element of the output includes the i-th | ||||
|             element of the input. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "conj", | ||||
|       [](const ScalarOrArray& a, StreamOrDevice s) { | ||||
|         return mlx::core::conjugate(to_array(a), s); | ||||
|       }, | ||||
|       nb::arg(), | ||||
|       nb::kw_only(), | ||||
|       "stream"_a = nb::none(), | ||||
|       nb::sig( | ||||
|           "def conj(a: array, *, stream: Union[None, Stream, Device] = None) -> array"), | ||||
|       R"pbdoc( | ||||
|         Return the elementwise complex conjugate of the input. | ||||
|         Alias for `mx.conjugate`. | ||||
|  | ||||
|         Args: | ||||
|           a (array): Input array | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "conjugate", | ||||
|       [](const ScalarOrArray& a, StreamOrDevice s) { | ||||
|         return mlx::core::conjugate(to_array(a), s); | ||||
|       }, | ||||
|       nb::arg(), | ||||
|       nb::kw_only(), | ||||
|       "stream"_a = nb::none(), | ||||
|       nb::sig( | ||||
|           "def conjugate(a: array, *, stream: Union[None, Stream, Device] = None) -> array"), | ||||
|       R"pbdoc( | ||||
|         Return the elementwise complex conjugate of the input. | ||||
|         Alias for `mx.conj`. | ||||
|  | ||||
|         Args: | ||||
|           a (array): Input array | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "convolve", | ||||
|       [](const array& a, | ||||
|   | ||||
| @@ -1245,6 +1245,7 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|             "log1p", | ||||
|             "floor", | ||||
|             "ceil", | ||||
|             "conjugate", | ||||
|         ] | ||||
|  | ||||
|         x = 0.5 | ||||
| @@ -2258,6 +2259,19 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|                 out_np = getattr(np, op)(a_np, b_np) | ||||
|                 self.assertTrue(np.array_equal(np.array(out_mlx), out_np)) | ||||
|  | ||||
|     def test_conjugate(self): | ||||
|         shape = (3, 5, 7) | ||||
|         a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape) | ||||
|         a = a.astype(np.complex64) | ||||
|         ops = ["conjugate", "conj"] | ||||
|         for op in ops: | ||||
|             out_mlx = getattr(mx, op)(mx.array(a)) | ||||
|             out_np = getattr(np, op)(a) | ||||
|             self.assertTrue(np.array_equal(np.array(out_mlx), out_np)) | ||||
|         out_mlx = mx.array(a).conj() | ||||
|         out_np = a.conj() | ||||
|         self.assertTrue(np.array_equal(np.array(out_mlx), out_np)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Alex Barron
					Alex Barron