Add conjugate operator (#1100)

* cpu and gpu impl

* add mx.conj and array.conj()

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
Alex Barron
2024-05-10 07:22:20 -07:00
committed by GitHub
parent 8bd6bfa4b5
commit 2e158cf6d0
17 changed files with 143 additions and 11 deletions

View File

@@ -1584,5 +1584,13 @@ void init_array(nb::module_& m) {
"stream"_a = nb::none(),
R"pbdoc(
Extract a diagonal or construct a diagonal matrix.
)pbdoc");
)pbdoc")
.def(
"conj",
[](const array& a, StreamOrDevice s) {
return mlx::core::conjugate(to_array(a), s);
},
nb::kw_only(),
"stream"_a = nb::none(),
"See :func:`conj`.");
}

View File

@@ -3027,6 +3027,40 @@ void init_ops(nb::module_& m) {
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
)pbdoc");
m.def(
"conj",
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::conjugate(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def conj(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Return the elementwise complex conjugate of the input.
Alias for `mx.conjugate`.
Args:
a (array): Input array
)pbdoc");
m.def(
"conjugate",
[](const ScalarOrArray& a, StreamOrDevice s) {
return mlx::core::conjugate(to_array(a), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def conjugate(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Return the elementwise complex conjugate of the input.
Alias for `mx.conj`.
Args:
a (array): Input array
)pbdoc");
m.def(
"convolve",
[](const array& a,