From 2e158cf6d044db4230eaaeb81b67b8fd55dccbd5 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Fri, 10 May 2024 07:22:20 -0700 Subject: [PATCH] Add conjugate operator (#1100) * cpu and gpu impl * add mx.conj and array.conj() --------- Co-authored-by: Alex Barron --- docs/src/python/ops.rst | 2 ++ mlx/backend/accelerate/primitives.cpp | 1 + mlx/backend/common/default_primitives.cpp | 1 + mlx/backend/common/ops.h | 6 ++++ mlx/backend/common/primitives.cpp | 11 ++++++++ mlx/backend/metal/kernels/unary.h | 6 ++++ mlx/backend/metal/kernels/unary.metal | 1 + mlx/backend/metal/primitives.cpp | 11 ++++++++ mlx/backend/no_metal/primitives.cpp | 1 + mlx/compile.cpp | 21 +++++++------- mlx/ops.cpp | 9 ++++++ mlx/ops.h | 2 ++ mlx/primitives.cpp | 8 ++++++ mlx/primitives.h | 16 +++++++++++ python/src/array.cpp | 10 ++++++- python/src/ops.cpp | 34 +++++++++++++++++++++++ python/tests/test_ops.py | 14 ++++++++++ 17 files changed, 143 insertions(+), 11 deletions(-) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 8cd648e31..177332c49 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -38,6 +38,8 @@ Operations ceil clip concatenate + conj + conjugate convolve conv1d conv2d diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 1fecd9ca6..9bf1868c2 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -36,6 +36,7 @@ DEFAULT(BlockSparseMM) DEFAULT(Broadcast) DEFAULT(Ceil) DEFAULT(Concatenate) +DEFAULT(Conjugate) DEFAULT(Copy) DEFAULT_MULTI(CustomVJP) DEFAULT_MULTI(Depends) diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index a0695feeb..ec5289d6a 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -47,6 +47,7 @@ DEFAULT(BlockSparseMM) DEFAULT_MULTI(DivMod) DEFAULT(Ceil) DEFAULT(Concatenate) +DEFAULT(Conjugate) DEFAULT(Convolution) DEFAULT(Copy) DEFAULT(Cos) diff --git a/mlx/backend/common/ops.h b/mlx/backend/common/ops.h index 8733dbe74..0fa0bef5a 100644 --- a/mlx/backend/common/ops.h +++ b/mlx/backend/common/ops.h @@ -209,6 +209,12 @@ struct Ceil { }; }; +struct Conjugate { + complex64_t operator()(complex64_t x) { + return std::conj(x); + } +}; + struct Cos { template T operator()(T x) { diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index 08b0775c8..442b09af0 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -203,6 +203,17 @@ void Concatenate::eval(const std::vector& inputs, array& out) { } } +void Conjugate::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == complex64) { + unary_fp(in, out, detail::Conjugate()); + } else { + throw std::invalid_argument( + "[conjugate] conjugate must be called on complex input."); + } +} + void Copy::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); out.copy_shared_buffer(inputs[0]); diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h index dd380f2c5..3752f6061 100644 --- a/mlx/backend/metal/kernels/unary.h +++ b/mlx/backend/metal/kernels/unary.h @@ -158,6 +158,12 @@ struct Cosh { }; }; +struct Conjugate { + complex64_t operator()(complex64_t x) { + return complex64_t{x.real, -x.imag}; + } +}; + struct Erf { template T operator()(T x) { diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index e9b52d58d..c1864ff14 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -94,6 +94,7 @@ instantiate_unary_float(tanh, Tanh) instantiate_unary_float(round, Round) instantiate_unary_all(abs, complex64, complex64_t, Abs) +instantiate_unary_all(conj, complex64, complex64_t, Conjugate) instantiate_unary_all(cos, complex64, complex64_t, Cos) instantiate_unary_all(cosh, complex64, complex64_t, Cosh) instantiate_unary_all(exp, complex64, complex64_t, Exp) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 06e9735a5..d989b2197 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -588,6 +588,17 @@ void Concatenate::eval_gpu(const std::vector& inputs, array& out) { } } +void Conjugate::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == complex64) { + unary_op(inputs, out, "conj"); + } else { + throw std::invalid_argument( + "[conjugate] conjugate must be called on complex input."); + } +} + void Copy::eval_gpu(const std::vector& inputs, array& out) { eval(inputs, out); } diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 0b6a8c4ab..63114d386 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -39,6 +39,7 @@ NO_GPU(Broadcast) NO_GPU(Ceil) NO_GPU_MULTI(Compiled) NO_GPU(Concatenate) +NO_GPU(Conjugate) NO_GPU(Convolution) NO_GPU(Copy) NO_GPU(Cos) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index f002b20eb..149ee7398 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -23,16 +23,17 @@ bool is_unary(const Primitive& p) { typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) || typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) || typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) || - typeid(p) == typeid(Cosh) || typeid(p) == typeid(Remainder) || - typeid(p) == typeid(Erf) || typeid(p) == typeid(ErfInv) || - typeid(p) == typeid(Exp) || typeid(p) == typeid(Floor) || - typeid(p) == typeid(Log) || typeid(p) == typeid(Log1p) || - typeid(p) == typeid(LogicalNot) || typeid(p) == typeid(Negative) || - typeid(p) == typeid(Round) || typeid(p) == typeid(Sigmoid) || - typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) || - typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) || - typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) || - typeid(p) == typeid(Tanh) || typeid(p) == typeid(Expm1)); + typeid(p) == typeid(Conjugate) || typeid(p) == typeid(Cosh) || + typeid(p) == typeid(Remainder) || typeid(p) == typeid(Erf) || + typeid(p) == typeid(ErfInv) || typeid(p) == typeid(Exp) || + typeid(p) == typeid(Floor) || typeid(p) == typeid(Log) || + typeid(p) == typeid(Log1p) || typeid(p) == typeid(LogicalNot) || + typeid(p) == typeid(Negative) || typeid(p) == typeid(Round) || + typeid(p) == typeid(Sigmoid) || typeid(p) == typeid(Sign) || + typeid(p) == typeid(Sin) || typeid(p) == typeid(Sinh) || + typeid(p) == typeid(Square) || typeid(p) == typeid(Sqrt) || + typeid(p) == typeid(Tan) || typeid(p) == typeid(Tanh) || + typeid(p) == typeid(Expm1)); } bool is_binary(const Primitive& p) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 1903a34ad..7ab013f4e 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4101,6 +4101,15 @@ array number_of_elements( {a})); } +array conjugate(const array& a, StreamOrDevice s /* = {} */) { + // Mirror NumPy's behaviour for real input + if (a.dtype() != complex64) { + return a; + } + return array( + a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); +} + array bitwise_impl( const array& a, const array& b, diff --git a/mlx/ops.h b/mlx/ops.h index 454b22c0d..2df60362c 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1239,6 +1239,8 @@ array number_of_elements( Dtype dtype = int32, StreamOrDevice s = {}); +array conjugate(const array& a, StreamOrDevice s = {}); + /** Bitwise and. */ array bitwise_and(const array& a, const array& b, StreamOrDevice s = {}); array operator&(const array& a, const array& b); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 6f8a0511a..d9c0739f0 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -789,6 +789,14 @@ bool Concatenate::is_equivalent(const Primitive& other) const { return axis_ == c_other.axis_; } +std::pair, std::vector> Conjugate::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {{conjugate(inputs[0], stream())}, axes}; +} + array conv_weight_backward_patches( const array& in, const array& wt, diff --git a/mlx/primitives.h b/mlx/primitives.h index 316ae13d5..868b5e7f5 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -620,6 +620,22 @@ class Concatenate : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; +class Conjugate : public UnaryPrimitive { + public: + explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_PRINT(Conjugate) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() + + private: + void eval(const std::vector& inputs, array& out); +}; + class Convolution : public UnaryPrimitive { public: explicit Convolution( diff --git a/python/src/array.cpp b/python/src/array.cpp index 05f4f323c..243328dbe 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -1584,5 +1584,13 @@ void init_array(nb::module_& m) { "stream"_a = nb::none(), R"pbdoc( Extract a diagonal or construct a diagonal matrix. - )pbdoc"); + )pbdoc") + .def( + "conj", + [](const array& a, StreamOrDevice s) { + return mlx::core::conjugate(to_array(a), s); + }, + nb::kw_only(), + "stream"_a = nb::none(), + "See :func:`conj`."); } diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 8c8caf06f..551d7ddda 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3027,6 +3027,40 @@ void init_ops(nb::module_& m) { inclusive (bool): The i-th element of the output includes the i-th element of the input. )pbdoc"); + m.def( + "conj", + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::conjugate(to_array(a), s); + }, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def conj(a: array, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Return the elementwise complex conjugate of the input. + Alias for `mx.conjugate`. + + Args: + a (array): Input array + )pbdoc"); + m.def( + "conjugate", + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::conjugate(to_array(a), s); + }, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def conjugate(a: array, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Return the elementwise complex conjugate of the input. + Alias for `mx.conj`. + + Args: + a (array): Input array + )pbdoc"); m.def( "convolve", [](const array& a, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index b141e6e3c..ea84a5007 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1245,6 +1245,7 @@ class TestOps(mlx_tests.MLXTestCase): "log1p", "floor", "ceil", + "conjugate", ] x = 0.5 @@ -2258,6 +2259,19 @@ class TestOps(mlx_tests.MLXTestCase): out_np = getattr(np, op)(a_np, b_np) self.assertTrue(np.array_equal(np.array(out_mlx), out_np)) + def test_conjugate(self): + shape = (3, 5, 7) + a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape) + a = a.astype(np.complex64) + ops = ["conjugate", "conj"] + for op in ops: + out_mlx = getattr(mx, op)(mx.array(a)) + out_np = getattr(np, op)(a) + self.assertTrue(np.array_equal(np.array(out_mlx), out_np)) + out_mlx = mx.array(a).conj() + out_np = a.conj() + self.assertTrue(np.array_equal(np.array(out_mlx), out_np)) + if __name__ == "__main__": unittest.main()