mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 19:28:14 +08:00
Add conjugate operator (#1100)
* cpu and gpu impl * add mx.conj and array.conj() --------- Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
@@ -158,6 +158,12 @@ struct Cosh {
|
||||
};
|
||||
};
|
||||
|
||||
struct Conjugate {
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return complex64_t{x.real, -x.imag};
|
||||
}
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
|
@@ -94,6 +94,7 @@ instantiate_unary_float(tanh, Tanh)
|
||||
instantiate_unary_float(round, Round)
|
||||
|
||||
instantiate_unary_all(abs, complex64, complex64_t, Abs)
|
||||
instantiate_unary_all(conj, complex64, complex64_t, Conjugate)
|
||||
instantiate_unary_all(cos, complex64, complex64_t, Cos)
|
||||
instantiate_unary_all(cosh, complex64, complex64_t, Cosh)
|
||||
instantiate_unary_all(exp, complex64, complex64_t, Exp)
|
||||
|
@@ -588,6 +588,17 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == complex64) {
|
||||
unary_op(inputs, out, "conj");
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[conjugate] conjugate must be called on complex input.");
|
||||
}
|
||||
}
|
||||
|
||||
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
Reference in New Issue
Block a user