Add conjugate operator (#1100)

* cpu and gpu impl

* add mx.conj and array.conj()

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
Alex Barron
2024-05-10 07:22:20 -07:00
committed by GitHub
parent 8bd6bfa4b5
commit 2e158cf6d0
17 changed files with 143 additions and 11 deletions

View File

@@ -158,6 +158,12 @@ struct Cosh {
};
};
struct Conjugate {
complex64_t operator()(complex64_t x) {
return complex64_t{x.real, -x.imag};
}
};
struct Erf {
template <typename T>
T operator()(T x) {

View File

@@ -94,6 +94,7 @@ instantiate_unary_float(tanh, Tanh)
instantiate_unary_float(round, Round)
instantiate_unary_all(abs, complex64, complex64_t, Abs)
instantiate_unary_all(conj, complex64, complex64_t, Conjugate)
instantiate_unary_all(cos, complex64, complex64_t, Cos)
instantiate_unary_all(cosh, complex64, complex64_t, Cosh)
instantiate_unary_all(exp, complex64, complex64_t, Exp)

View File

@@ -588,6 +588,17 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == complex64) {
unary_op(inputs, out, "conj");
} else {
throw std::invalid_argument(
"[conjugate] conjugate must be called on complex input.");
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}