Add conjugate operator (#1100)

* CPU and GPU implementations

* Add mx.conj and array.conj() (see the usage sketch below)

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
Alex Barron 2024-05-10 07:22:20 -07:00 committed by GitHub
parent 8bd6bfa4b5
commit 2e158cf6d0
17 changed files with 143 additions and 11 deletions
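For orientation before the per-file diffs, here is a minimal usage sketch of the new API (illustrative values; the inline result comments are approximate, not captured output):

import mlx.core as mx

x = mx.array([1 + 2j, 3 - 4j])  # Python complex literals become complex64
mx.conj(x)                      # [1-2j, 3+4j]
mx.conjugate(x)                 # same result; `conjugate` and `conj` are aliases
x.conj()                        # new array method, equivalent to mx.conj(x)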


@ -38,6 +38,8 @@ Operations
ceil
clip
concatenate
conj
conjugate
convolve
conv1d
conv2d


@ -36,6 +36,7 @@ DEFAULT(BlockSparseMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)


@ -47,6 +47,7 @@ DEFAULT(BlockSparseMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)


@ -209,6 +209,12 @@ struct Ceil {
  };
};

struct Conjugate {
  complex64_t operator()(complex64_t x) {
    return std::conj(x);
  }
};

struct Cos {
  template <typename T>
  T operator()(T x) {


@ -203,6 +203,17 @@ void Concatenate::eval(const std::vector<array>& inputs, array& out) {
  }
}

void Conjugate::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == complex64) {
    unary_fp(in, out, detail::Conjugate());
  } else {
    throw std::invalid_argument(
        "[conjugate] conjugate must be called on complex input.");
  }
}

void Copy::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  out.copy_shared_buffer(inputs[0]);


@ -158,6 +158,12 @@ struct Cosh {
  };
};

struct Conjugate {
  complex64_t operator()(complex64_t x) {
    return complex64_t{x.real, -x.imag};
  }
};

struct Erf {
  template <typename T>
  T operator()(T x) {


@ -94,6 +94,7 @@ instantiate_unary_float(tanh, Tanh)
instantiate_unary_float(round, Round)
instantiate_unary_all(abs, complex64, complex64_t, Abs)
instantiate_unary_all(conj, complex64, complex64_t, Conjugate)
instantiate_unary_all(cos, complex64, complex64_t, Cos)
instantiate_unary_all(cosh, complex64, complex64_t, Cosh)
instantiate_unary_all(exp, complex64, complex64_t, Exp)


@ -588,6 +588,17 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
  }
}

void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == complex64) {
    unary_op(inputs, out, "conj");
  } else {
    throw std::invalid_argument(
        "[conjugate] conjugate must be called on complex input.");
  }
}

void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}
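Because both a CPU path and a Metal "conj" kernel are registered, the op can be pinned to a device with the usual stream argument. A minimal sketch, assuming a machine with a Metal GPU:

import mlx.core as mx

x = mx.array([1 + 2j, 3 - 4j])
mx.conj(x, stream=mx.gpu)  # dispatches the Metal "conj" unary kernel above
mx.conj(x, stream=mx.cpu)  # uses the CPU implementation (std::conj)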


@ -39,6 +39,7 @@ NO_GPU(Broadcast)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Concatenate)
NO_GPU(Conjugate)
NO_GPU(Convolution)
NO_GPU(Copy)
NO_GPU(Cos)


@ -23,16 +23,17 @@ bool is_unary(const Primitive& p) {
      typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) ||
      typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) ||
      typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) ||
      typeid(p) == typeid(Cosh) || typeid(p) == typeid(Remainder) ||
      typeid(p) == typeid(Erf) || typeid(p) == typeid(ErfInv) ||
      typeid(p) == typeid(Exp) || typeid(p) == typeid(Floor) ||
      typeid(p) == typeid(Log) || typeid(p) == typeid(Log1p) ||
      typeid(p) == typeid(LogicalNot) || typeid(p) == typeid(Negative) ||
      typeid(p) == typeid(Round) || typeid(p) == typeid(Sigmoid) ||
      typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
      typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
      typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
      typeid(p) == typeid(Tanh) || typeid(p) == typeid(Expm1));
      typeid(p) == typeid(Conjugate) || typeid(p) == typeid(Cosh) ||
      typeid(p) == typeid(Remainder) || typeid(p) == typeid(Erf) ||
      typeid(p) == typeid(ErfInv) || typeid(p) == typeid(Exp) ||
      typeid(p) == typeid(Floor) || typeid(p) == typeid(Log) ||
      typeid(p) == typeid(Log1p) || typeid(p) == typeid(LogicalNot) ||
      typeid(p) == typeid(Negative) || typeid(p) == typeid(Round) ||
      typeid(p) == typeid(Sigmoid) || typeid(p) == typeid(Sign) ||
      typeid(p) == typeid(Sin) || typeid(p) == typeid(Sinh) ||
      typeid(p) == typeid(Square) || typeid(p) == typeid(Sqrt) ||
      typeid(p) == typeid(Tan) || typeid(p) == typeid(Tanh) ||
      typeid(p) == typeid(Expm1));
}
bool is_binary(const Primitive& p) {


@ -4101,6 +4101,15 @@ array number_of_elements(
      {a}));
}

array conjugate(const array& a, StreamOrDevice s /* = {} */) {
  // Mirror NumPy's behaviour for real input
  if (a.dtype() != complex64) {
    return a;
  }
  return array(
      a.shape(), a.dtype(), std::make_shared<Conjugate>(to_stream(s)), {a});
}
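A quick sketch of the real-input behaviour described in the comment above: non-complex inputs skip the Conjugate primitive entirely and are returned unchanged, mirroring np.conj on real arrays (illustrative values):

import mlx.core as mx

r = mx.array([1.0, -2.0])    # float32, not complex64
mx.conj(r)                   # returned as-is; no Conjugate node is added to the graph
mx.conj(mx.array([1 + 1j]))  # complex input goes through the Conjugate primitive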
array bitwise_impl(
    const array& a,
    const array& b,


@ -1239,6 +1239,8 @@ array number_of_elements(
    Dtype dtype = int32,
    StreamOrDevice s = {});

array conjugate(const array& a, StreamOrDevice s = {});

/** Bitwise and. */
array bitwise_and(const array& a, const array& b, StreamOrDevice s = {});
array operator&(const array& a, const array& b);


@ -789,6 +789,14 @@ bool Concatenate::is_equivalent(const Primitive& other) const {
  return axis_ == c_other.axis_;
}

std::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  assert(inputs.size() == 1);
  assert(axes.size() == 1);
  return {{conjugate(inputs[0], stream())}, axes};
}
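Since the vmap rule simply applies conjugate to the vmapped input and forwards the batch axes, transforming the op with mx.vmap behaves like the elementwise op itself. A minimal sketch (hypothetical data):

import mlx.core as mx

batch = mx.array([[1 + 2j, 3 - 4j], [0 + 1j, 2 + 0j]])
mx.vmap(mx.conj)(batch)  # same values as mx.conj(batch), batched over axis 0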
array conv_weight_backward_patches(
    const array& in,
    const array& wt,


@ -620,6 +620,22 @@ class Concatenate : public UnaryPrimitive {
  void eval(const std::vector<array>& inputs, array& out);
};

class Conjugate : public UnaryPrimitive {
 public:
  explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_VMAP()
  DEFINE_PRINT(Conjugate)
  DEFINE_DEFAULT_IS_EQUIVALENT()
  DEFINE_INPUT_OUTPUT_SHAPE()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

class Convolution : public UnaryPrimitive {
 public:
  explicit Convolution(


@ -1584,5 +1584,13 @@ void init_array(nb::module_& m) {
          "stream"_a = nb::none(),
          R"pbdoc(
            Extract a diagonal or construct a diagonal matrix.
          )pbdoc");
          )pbdoc")
      .def(
          "conj",
          [](const array& a, StreamOrDevice s) {
            return mlx::core::conjugate(to_array(a), s);
          },
          nb::kw_only(),
          "stream"_a = nb::none(),
          "See :func:`conj`.");
}


@ -3027,6 +3027,40 @@ void init_ops(nb::module_& m) {
            inclusive (bool): The i-th element of the output includes the i-th
              element of the input.
      )pbdoc");
  m.def(
      "conj",
      [](const ScalarOrArray& a, StreamOrDevice s) {
        return mlx::core::conjugate(to_array(a), s);
      },
      nb::arg(),
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def conj(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Return the elementwise complex conjugate of the input.
        Alias for `mx.conjugate`.

        Args:
            a (array): Input array
      )pbdoc");
  m.def(
      "conjugate",
      [](const ScalarOrArray& a, StreamOrDevice s) {
        return mlx::core::conjugate(to_array(a), s);
      },
      nb::arg(),
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def conjugate(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Return the elementwise complex conjugate of the input.
        Alias for `mx.conj`.

        Args:
            a (array): Input array
      )pbdoc");
  m.def(
      "convolve",
      [](const array& a,


@ -1245,6 +1245,7 @@ class TestOps(mlx_tests.MLXTestCase):
            "log1p",
            "floor",
            "ceil",
            "conjugate",
        ]
        x = 0.5
@ -2258,6 +2259,19 @@ class TestOps(mlx_tests.MLXTestCase):
            out_np = getattr(np, op)(a_np, b_np)
            self.assertTrue(np.array_equal(np.array(out_mlx), out_np))

    def test_conjugate(self):
        shape = (3, 5, 7)
        a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
        a = a.astype(np.complex64)

        ops = ["conjugate", "conj"]
        for op in ops:
            out_mlx = getattr(mx, op)(mx.array(a))
            out_np = getattr(np, op)(a)
            self.assertTrue(np.array_equal(np.array(out_mlx), out_np))

        out_mlx = mx.array(a).conj()
        out_np = a.conj()
        self.assertTrue(np.array_equal(np.array(out_mlx), out_np))


if __name__ == "__main__":
    unittest.main()