diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index e3c50e2ff..572b02a98 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -80,6 +80,7 @@ Operations greater_equal hadamard_transform identity + imag inner isfinite isclose @@ -125,6 +126,7 @@ Operations quantize quantized_matmul radians + real reciprocal remainder repeat diff --git a/mlx/backend/common/ops.h b/mlx/backend/common/ops.h index 058feddbc..824ea4724 100644 --- a/mlx/backend/common/ops.h +++ b/mlx/backend/common/ops.h @@ -295,6 +295,13 @@ struct Floor { } }; +struct Imag { + template + T operator()(T x) { + return std::imag(x); + } +}; + struct Log { template T operator()(T x) { @@ -337,6 +344,13 @@ struct Negative { } }; +struct Real { + template + T operator()(T x) { + return std::real(x); + } +}; + struct Round { template T operator()(T x) { diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index f015f9995..88d53b447 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -273,6 +273,10 @@ void Full::eval(const std::vector& inputs, array& out) { copy(in, out, ctype); } +void Imag::eval_cpu(const std::vector& inputs, array& out) { + unary_op(inputs[0], out, detail::Imag()); +} + void Log::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; @@ -398,6 +402,10 @@ void RandomBits::eval(const std::vector& inputs, array& out) { } } +void Real::eval_cpu(const std::vector& inputs, array& out) { + unary_op(inputs[0], out, detail::Real()); +} + void Reshape::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h index f9e682777..944f5034a 100644 --- a/mlx/backend/common/unary.h +++ b/mlx/backend/common/unary.h @@ -24,26 +24,26 @@ void set_unary_output_data(const array& in, array& out) { } } -template -void unary_op(const T* a, T* out, Op op, size_t shape, size_t stride) { +template +void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) { for (size_t i = 0; i < shape; i += 1) { out[i] = op(*a); a += stride; } } -template +template void unary_op(const array& a, array& out, Op op) { const T* a_ptr = a.data(); if (a.flags().contiguous) { set_unary_output_data(a, out); - T* dst = out.data(); + U* dst = out.data(); for (size_t i = 0; i < a.data_size(); ++i) { dst[i] = op(a_ptr[i]); } } else { out.set_data(allocator::malloc_or_wait(out.nbytes())); - T* dst = out.data(); + U* dst = out.data(); size_t shape = a.ndim() > 0 ? a.shape(-1) : 1; size_t stride = a.ndim() > 0 ? a.strides(-1) : 1; if (a.ndim() <= 1) { diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index d0229f4cb..430ff65af 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -40,18 +40,21 @@ MTL::ComputePipelineState* get_arange_kernel( MTL::ComputePipelineState* get_unary_kernel( metal::Device& d, const std::string& kernel_name, + Dtype in_type, Dtype out_type, const std::string op) { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name, [&]() { + auto in_t = get_type_string(in_type); + auto out_t = get_type_string(out_type); std::ostringstream kernel_source; kernel_source << metal::utils() << metal::unary_ops() << metal::unary(); kernel_source << get_template_definition( - "v_" + lib_name, "unary_v", get_type_string(out_type), op); + "v_" + lib_name, "unary_v", in_t, out_t, op); kernel_source << get_template_definition( - "v2_" + lib_name, "unary_v2", get_type_string(out_type), op); + "v2_" + lib_name, "unary_v2", in_t, out_t, op); kernel_source << get_template_definition( - "gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4); + "gn4_" + lib_name, "unary_g", in_t, out_t, op, 4); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 51d5121d8..2f861373d 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -15,6 +15,7 @@ MTL::ComputePipelineState* get_arange_kernel( MTL::ComputePipelineState* get_unary_kernel( metal::Device& d, const std::string& kernel_name, + Dtype in_type, Dtype out_type, const std::string op); diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h index 8d404ae25..402d936c7 100644 --- a/mlx/backend/metal/kernels/unary.h +++ b/mlx/backend/metal/kernels/unary.h @@ -1,27 +1,27 @@ // Copyright © 2024 Apple Inc. -template +template [[kernel]] void unary_v( device const T* in, - device T* out, + device U* out, uint index [[thread_position_in_grid]]) { out[index] = Op()(in[index]); } -template +template [[kernel]] void unary_v2( device const T* in, - device T* out, + device U* out, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { size_t offset = index.x + grid_dim.x * size_t(index.y); out[offset] = Op()(in[offset]); } -template +template [[kernel]] void unary_g( device const T* in, - device T* out, + device U* out, constant const int* in_shape, constant const size_t* in_strides, device const int& ndim, diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index f301dce60..463708ab6 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -5,26 +5,30 @@ #include "mlx/backend/metal/kernels/unary_ops.h" #include "mlx/backend/metal/kernels/unary.h" -#define instantiate_unary_all(op, tname, type) \ - instantiate_kernel("v_" #op #tname, unary_v, type, op) \ - instantiate_kernel("v2_" #op #tname, unary_v2, type, op) \ - instantiate_kernel("gn4_" #op #tname, unary_g, type, op, 4) +#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \ + instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \ + instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \ + instantiate_kernel("gn4_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4) + + +#define instantiate_unary_all_same(op, tname, type) \ + instantiate_unary_all(op, tname, tname, type, type) #define instantiate_unary_float(op) \ - instantiate_unary_all(op, float16, half) \ - instantiate_unary_all(op, float32, float) \ - instantiate_unary_all(op, bfloat16, bfloat16_t) + instantiate_unary_all_same(op, float16, half) \ + instantiate_unary_all_same(op, float32, float) \ + instantiate_unary_all_same(op, bfloat16, bfloat16_t) #define instantiate_unary_types(op) \ - instantiate_unary_all(op, bool_, bool) \ - instantiate_unary_all(op, uint8, uint8_t) \ - instantiate_unary_all(op, uint16, uint16_t) \ - instantiate_unary_all(op, uint32, uint32_t) \ - instantiate_unary_all(op, uint64, uint64_t) \ - instantiate_unary_all(op, int8, int8_t) \ - instantiate_unary_all(op, int16, int16_t) \ - instantiate_unary_all(op, int32, int32_t) \ - instantiate_unary_all(op, int64, int64_t) \ + instantiate_unary_all_same(op, bool_, bool) \ + instantiate_unary_all_same(op, uint8, uint8_t) \ + instantiate_unary_all_same(op, uint16, uint16_t) \ + instantiate_unary_all_same(op, uint32, uint32_t) \ + instantiate_unary_all_same(op, uint64, uint64_t) \ + instantiate_unary_all_same(op, int8, int8_t) \ + instantiate_unary_all_same(op, int16, int16_t) \ + instantiate_unary_all_same(op, int32, int32_t) \ + instantiate_unary_all_same(op, int64, int64_t) \ instantiate_unary_float(op) instantiate_unary_types(Abs) @@ -58,17 +62,19 @@ instantiate_unary_float(Tan) instantiate_unary_float(Tanh) instantiate_unary_float(Round) -instantiate_unary_all(Abs, complex64, complex64_t) -instantiate_unary_all(Conjugate, complex64, complex64_t) -instantiate_unary_all(Cos, complex64, complex64_t) -instantiate_unary_all(Cosh, complex64, complex64_t) -instantiate_unary_all(Exp, complex64, complex64_t) -instantiate_unary_all(Negative, complex64, complex64_t) -instantiate_unary_all(Sign, complex64, complex64_t) -instantiate_unary_all(Sin, complex64, complex64_t) -instantiate_unary_all(Sinh, complex64, complex64_t) -instantiate_unary_all(Tan, complex64, complex64_t) -instantiate_unary_all(Tanh, complex64, complex64_t) -instantiate_unary_all(Round, complex64, complex64_t) +instantiate_unary_all_same(Abs, complex64, complex64_t) +instantiate_unary_all_same(Conjugate, complex64, complex64_t) +instantiate_unary_all_same(Cos, complex64, complex64_t) +instantiate_unary_all_same(Cosh, complex64, complex64_t) +instantiate_unary_all_same(Exp, complex64, complex64_t) +instantiate_unary_all_same(Negative, complex64, complex64_t) +instantiate_unary_all_same(Sign, complex64, complex64_t) +instantiate_unary_all_same(Sin, complex64, complex64_t) +instantiate_unary_all_same(Sinh, complex64, complex64_t) +instantiate_unary_all_same(Tan, complex64, complex64_t) +instantiate_unary_all_same(Tanh, complex64, complex64_t) +instantiate_unary_all_same(Round, complex64, complex64_t) +instantiate_unary_all(Real, complex64, float32, complex64_t, float) +instantiate_unary_all(Imag, complex64, float32, complex64_t, float) -instantiate_unary_all(LogicalNot, bool_, bool) // clang-format on +instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h index 857185a9d..ebb5573e4 100644 --- a/mlx/backend/metal/kernels/unary_ops.h +++ b/mlx/backend/metal/kernels/unary_ops.h @@ -238,6 +238,13 @@ struct Floor { }; }; +struct Imag { + template + T operator()(T x) { + return x.imag; + }; +}; + struct Log { template T operator()(T x) { @@ -280,6 +287,13 @@ struct Negative { }; }; +struct Real { + template + T operator()(T x) { + return x.real; + }; +}; + struct Round { template T operator()(T x) { diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 3947ea419..006e2ae46 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -16,6 +16,7 @@ MTL::ComputePipelineState* get_unary_kernel( metal::Device& d, const std::string& kernel_name, Dtype, + Dtype, const std::string) { return d.get_kernel(kernel_name); } diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index a3903b89c..acb469f15 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -44,8 +44,8 @@ void unary_op_gpu_inplace( } else { kernel_name = (work_per_thread == 4 ? "gn4" : "g"); } - kernel_name += "_" + op + type_to_name(out); - auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op); + kernel_name += "_" + op + type_to_name(in) + type_to_name(out); + auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op); MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides()) : MTL::Size(nthreads, 1, 1); @@ -124,11 +124,13 @@ UNARY_GPU(Erf) UNARY_GPU(ErfInv) UNARY_GPU(Exp) UNARY_GPU(Expm1) +UNARY_GPU(Imag) UNARY_GPU(Log1p) UNARY_GPU(LogicalNot) UNARY_GPU(Floor) UNARY_GPU(Ceil) UNARY_GPU(Negative) +UNARY_GPU(Real) UNARY_GPU(Sigmoid) UNARY_GPU(Sign) UNARY_GPU(Sin) diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index ff60e4d22..fd15c403b 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -62,6 +62,7 @@ NO_CPU(GatherQMM) NO_CPU(Greater) NO_CPU(GreaterEqual) NO_CPU(Hadamard) +NO_CPU(Imag) NO_CPU(Less) NO_CPU(LessEqual) NO_CPU(Load) @@ -83,6 +84,7 @@ NO_CPU(Power) NO_CPU_MULTI(QRF) NO_CPU(QuantizedMatmul) NO_CPU(RandomBits) +NO_CPU(Real) NO_CPU(Reduce) NO_CPU(Reshape) NO_CPU(Round) diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 544a2c6f2..5270a6fdd 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -64,6 +64,7 @@ NO_GPU(GatherQMM) NO_GPU(Greater) NO_GPU(GreaterEqual) NO_GPU(Hadamard) +NO_GPU(Imag) NO_GPU(Less) NO_GPU(LessEqual) NO_GPU(Load) @@ -85,6 +86,7 @@ NO_GPU(Power) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) NO_GPU(RandomBits) +NO_GPU(Real) NO_GPU(Reduce) NO_GPU(Reshape) NO_GPU(Round) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index f5082010c..9b52baa91 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -33,7 +33,8 @@ bool is_unary(const Primitive& p) { typeid(p) == typeid(Sin) || typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) || typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) || typeid(p) == typeid(Tanh) || - typeid(p) == typeid(Expm1)); + typeid(p) == typeid(Expm1) || typeid(p) == typeid(Real) || + typeid(p) == typeid(Imag)); } bool is_binary(const Primitive& p) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c415e4504..680240439 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4657,4 +4657,18 @@ array roll( return roll(a, std::vector{total_shift}, std::vector{axis}, s); } +array real(const array& a, StreamOrDevice s /* = {} */) { + if (!issubdtype(a.dtype(), complexfloating)) { + return a; + } + return array(a.shape(), float32, std::make_shared(to_stream(s)), {a}); +} + +array imag(const array& a, StreamOrDevice s /* = {} */) { + if (!issubdtype(a.dtype(), complexfloating)) { + return zeros_like(a); + } + return array(a.shape(), float32, std::make_shared(to_stream(s)), {a}); +} + } // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index 32f77514f..7b8e9327e 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1480,6 +1480,12 @@ array roll( const std::vector& axes, StreamOrDevice s = {}); +/* The real part of a complex array. */ +array real(const array& a, StreamOrDevice s = {}); + +/* The imaginary part of a complex array. */ +array imag(const array& a, StreamOrDevice s = {}); + /** @} */ } // namespace mlx::core diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index ddcc7d938..8aa0392b7 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1797,6 +1797,36 @@ std::vector GreaterEqual::jvp( return {zeros(shape, bool_, stream())}; } +std::vector Imag::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return {multiply( + array(complex64_t{0.0f, -1.0f}, primals[0].dtype()), + cotangents[0], + stream())}; +} + +std::vector Imag::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return {imag(tangents[0], stream())}; +} + +std::pair, std::vector> Imag::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {{imag(inputs[0], stream())}, axes}; +} + std::pair, std::vector> Less::vmap( const std::vector& inputs, const std::vector& axes) { @@ -2633,6 +2663,33 @@ bool RandomBits::is_equivalent(const Primitive& other) const { return shape_ == r_other.shape_; } +std::vector Real::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return {astype(cotangents[0], primals[0].dtype(), stream())}; +} + +std::vector Real::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return {real(tangents[0], stream())}; +} + +std::pair, std::vector> Real::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {{real(inputs[0], stream())}, axes}; +} + std::pair, std::vector> Reshape::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 810eb5096..4bec71445 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1106,6 +1106,20 @@ class Hadamard : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; +class Imag : public UnaryPrimitive { + public: + explicit Imag(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_PRINT(Imag) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + class Less : public UnaryPrimitive { public: explicit Less(Stream stream) : UnaryPrimitive(stream) {} @@ -1561,6 +1575,20 @@ class RandomBits : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; +class Real : public UnaryPrimitive { + public: + explicit Real(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_PRINT(Real) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + class Reshape : public UnaryPrimitive { public: explicit Reshape(Stream stream, const std::vector& shape) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 4ffa21dd9..a17c9ea0b 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4842,4 +4842,42 @@ void init_ops(nb::module_& m) { axis (int or tuple(int), optional): The axis or axes along which to roll the elements. )pbdoc"); + m.def( + "real", + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::real(to_array(a), s); + }, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def real(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Returns the real part of a complex array. + + Args: + a (array): Input array. + + Returns: + array: The real part of ``a``. + )pbdoc"); + m.def( + "imag", + [](const ScalarOrArray& a, StreamOrDevice s) { + return mlx::core::imag(to_array(a), s); + }, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def imag(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Returns the imaginary part of a complex array. + + Args: + a (array): Input array. + + Returns: + array: The imaginary part of ``a``. + )pbdoc"); } diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 1e99c3825..5f3b62a8a 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -590,6 +590,21 @@ class TestAutograd(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(out1[0], out2[0])) self.assertTrue(mx.allclose(dout1[0] + 1, dout2[0])) + def test_complex_vjps(self): + def fun(x): + return (2.0 * mx.real(x)).sum() + + x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j]) + dfdx = mx.grad(fun)(x) + self.assertTrue(mx.allclose(dfdx, 2 * mx.ones_like(x))) + + def fun(x): + return (2.0 * mx.imag(x)).sum() + + x = mx.array([0.0 + 1j, 1.0 + 0.0j, 0.5 + 0.5j]) + dfdx = mx.grad(fun)(x) + self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x))) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 311858670..34b2d66bf 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2680,6 +2680,21 @@ class TestOps(mlx_tests.MLXTestCase): y2 = mx.roll(x, s, a) self.assertTrue(mx.array_equal(y1, y2).item()) + def test_real_imag(self): + x = mx.random.uniform(shape=(4, 4)) + out = mx.real(x) + self.assertTrue(mx.array_equal(x, out)) + + out = mx.imag(x) + self.assertTrue(mx.array_equal(mx.zeros_like(x), out)) + + y = mx.random.uniform(shape=(4, 4)) + z = x + 1j * y + self.assertEqual(mx.real(z).dtype, mx.float32) + self.assertTrue(mx.array_equal(mx.real(z), x)) + self.assertEqual(mx.imag(z).dtype, mx.float32) + self.assertTrue(mx.array_equal(mx.imag(z), y)) + if __name__ == "__main__": unittest.main()