mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 18:51:14 +08:00
Add conjugate operator (#1100)
* cpu and gpu impl * add mx.conj and array.conj() --------- Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
parent
8bd6bfa4b5
commit
2e158cf6d0
@ -38,6 +38,8 @@ Operations
|
|||||||
ceil
|
ceil
|
||||||
clip
|
clip
|
||||||
concatenate
|
concatenate
|
||||||
|
conj
|
||||||
|
conjugate
|
||||||
convolve
|
convolve
|
||||||
conv1d
|
conv1d
|
||||||
conv2d
|
conv2d
|
||||||
|
@ -36,6 +36,7 @@ DEFAULT(BlockSparseMM)
|
|||||||
DEFAULT(Broadcast)
|
DEFAULT(Broadcast)
|
||||||
DEFAULT(Ceil)
|
DEFAULT(Ceil)
|
||||||
DEFAULT(Concatenate)
|
DEFAULT(Concatenate)
|
||||||
|
DEFAULT(Conjugate)
|
||||||
DEFAULT(Copy)
|
DEFAULT(Copy)
|
||||||
DEFAULT_MULTI(CustomVJP)
|
DEFAULT_MULTI(CustomVJP)
|
||||||
DEFAULT_MULTI(Depends)
|
DEFAULT_MULTI(Depends)
|
||||||
|
@ -47,6 +47,7 @@ DEFAULT(BlockSparseMM)
|
|||||||
DEFAULT_MULTI(DivMod)
|
DEFAULT_MULTI(DivMod)
|
||||||
DEFAULT(Ceil)
|
DEFAULT(Ceil)
|
||||||
DEFAULT(Concatenate)
|
DEFAULT(Concatenate)
|
||||||
|
DEFAULT(Conjugate)
|
||||||
DEFAULT(Convolution)
|
DEFAULT(Convolution)
|
||||||
DEFAULT(Copy)
|
DEFAULT(Copy)
|
||||||
DEFAULT(Cos)
|
DEFAULT(Cos)
|
||||||
|
@ -209,6 +209,12 @@ struct Ceil {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct Conjugate {
|
||||||
|
complex64_t operator()(complex64_t x) {
|
||||||
|
return std::conj(x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct Cos {
|
struct Cos {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
|
@ -203,6 +203,17 @@ void Concatenate::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Conjugate::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == complex64) {
|
||||||
|
unary_fp(in, out, detail::Conjugate());
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[conjugate] conjugate must be called on complex input.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
out.copy_shared_buffer(inputs[0]);
|
out.copy_shared_buffer(inputs[0]);
|
||||||
|
@ -158,6 +158,12 @@ struct Cosh {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct Conjugate {
|
||||||
|
complex64_t operator()(complex64_t x) {
|
||||||
|
return complex64_t{x.real, -x.imag};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct Erf {
|
struct Erf {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
|
@ -94,6 +94,7 @@ instantiate_unary_float(tanh, Tanh)
|
|||||||
instantiate_unary_float(round, Round)
|
instantiate_unary_float(round, Round)
|
||||||
|
|
||||||
instantiate_unary_all(abs, complex64, complex64_t, Abs)
|
instantiate_unary_all(abs, complex64, complex64_t, Abs)
|
||||||
|
instantiate_unary_all(conj, complex64, complex64_t, Conjugate)
|
||||||
instantiate_unary_all(cos, complex64, complex64_t, Cos)
|
instantiate_unary_all(cos, complex64, complex64_t, Cos)
|
||||||
instantiate_unary_all(cosh, complex64, complex64_t, Cosh)
|
instantiate_unary_all(cosh, complex64, complex64_t, Cosh)
|
||||||
instantiate_unary_all(exp, complex64, complex64_t, Exp)
|
instantiate_unary_all(exp, complex64, complex64_t, Exp)
|
||||||
|
@ -588,6 +588,17 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (out.dtype() == complex64) {
|
||||||
|
unary_op(inputs, out, "conj");
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[conjugate] conjugate must be called on complex input.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
}
|
}
|
||||||
|
@ -39,6 +39,7 @@ NO_GPU(Broadcast)
|
|||||||
NO_GPU(Ceil)
|
NO_GPU(Ceil)
|
||||||
NO_GPU_MULTI(Compiled)
|
NO_GPU_MULTI(Compiled)
|
||||||
NO_GPU(Concatenate)
|
NO_GPU(Concatenate)
|
||||||
|
NO_GPU(Conjugate)
|
||||||
NO_GPU(Convolution)
|
NO_GPU(Convolution)
|
||||||
NO_GPU(Copy)
|
NO_GPU(Copy)
|
||||||
NO_GPU(Cos)
|
NO_GPU(Cos)
|
||||||
|
@ -23,16 +23,17 @@ bool is_unary(const Primitive& p) {
|
|||||||
typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) ||
|
typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) ||
|
||||||
typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) ||
|
typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) ||
|
||||||
typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) ||
|
typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) ||
|
||||||
typeid(p) == typeid(Cosh) || typeid(p) == typeid(Remainder) ||
|
typeid(p) == typeid(Conjugate) || typeid(p) == typeid(Cosh) ||
|
||||||
typeid(p) == typeid(Erf) || typeid(p) == typeid(ErfInv) ||
|
typeid(p) == typeid(Remainder) || typeid(p) == typeid(Erf) ||
|
||||||
typeid(p) == typeid(Exp) || typeid(p) == typeid(Floor) ||
|
typeid(p) == typeid(ErfInv) || typeid(p) == typeid(Exp) ||
|
||||||
typeid(p) == typeid(Log) || typeid(p) == typeid(Log1p) ||
|
typeid(p) == typeid(Floor) || typeid(p) == typeid(Log) ||
|
||||||
typeid(p) == typeid(LogicalNot) || typeid(p) == typeid(Negative) ||
|
typeid(p) == typeid(Log1p) || typeid(p) == typeid(LogicalNot) ||
|
||||||
typeid(p) == typeid(Round) || typeid(p) == typeid(Sigmoid) ||
|
typeid(p) == typeid(Negative) || typeid(p) == typeid(Round) ||
|
||||||
typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
|
typeid(p) == typeid(Sigmoid) || typeid(p) == typeid(Sign) ||
|
||||||
typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
|
typeid(p) == typeid(Sin) || typeid(p) == typeid(Sinh) ||
|
||||||
typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
|
typeid(p) == typeid(Square) || typeid(p) == typeid(Sqrt) ||
|
||||||
typeid(p) == typeid(Tanh) || typeid(p) == typeid(Expm1));
|
typeid(p) == typeid(Tan) || typeid(p) == typeid(Tanh) ||
|
||||||
|
typeid(p) == typeid(Expm1));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_binary(const Primitive& p) {
|
bool is_binary(const Primitive& p) {
|
||||||
|
@ -4101,6 +4101,15 @@ array number_of_elements(
|
|||||||
{a}));
|
{a}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array conjugate(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
// Mirror NumPy's behaviour for real input
|
||||||
|
if (a.dtype() != complex64) {
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
return array(
|
||||||
|
a.shape(), a.dtype(), std::make_shared<Conjugate>(to_stream(s)), {a});
|
||||||
|
}
|
||||||
|
|
||||||
array bitwise_impl(
|
array bitwise_impl(
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
|
@ -1239,6 +1239,8 @@ array number_of_elements(
|
|||||||
Dtype dtype = int32,
|
Dtype dtype = int32,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
array conjugate(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Bitwise and. */
|
/** Bitwise and. */
|
||||||
array bitwise_and(const array& a, const array& b, StreamOrDevice s = {});
|
array bitwise_and(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
array operator&(const array& a, const array& b);
|
array operator&(const array& a, const array& b);
|
||||||
|
@ -789,6 +789,14 @@ bool Concatenate::is_equivalent(const Primitive& other) const {
|
|||||||
return axis_ == c_other.axis_;
|
return axis_ == c_other.axis_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(axes.size() == 1);
|
||||||
|
return {{conjugate(inputs[0], stream())}, axes};
|
||||||
|
}
|
||||||
|
|
||||||
array conv_weight_backward_patches(
|
array conv_weight_backward_patches(
|
||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
|
@ -620,6 +620,22 @@ class Concatenate : public UnaryPrimitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Conjugate : public UnaryPrimitive {
|
||||||
|
public:
|
||||||
|
explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {};
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
|
DEFINE_PRINT(Conjugate)
|
||||||
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
|
private:
|
||||||
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
};
|
||||||
|
|
||||||
class Convolution : public UnaryPrimitive {
|
class Convolution : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Convolution(
|
explicit Convolution(
|
||||||
|
@ -1584,5 +1584,13 @@ void init_array(nb::module_& m) {
|
|||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Extract a diagonal or construct a diagonal matrix.
|
Extract a diagonal or construct a diagonal matrix.
|
||||||
)pbdoc");
|
)pbdoc")
|
||||||
|
.def(
|
||||||
|
"conj",
|
||||||
|
[](const array& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::conjugate(to_array(a), s);
|
||||||
|
},
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
"See :func:`conj`.");
|
||||||
}
|
}
|
||||||
|
@ -3027,6 +3027,40 @@ void init_ops(nb::module_& m) {
|
|||||||
inclusive (bool): The i-th element of the output includes the i-th
|
inclusive (bool): The i-th element of the output includes the i-th
|
||||||
element of the input.
|
element of the input.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"conj",
|
||||||
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::conjugate(to_array(a), s);
|
||||||
|
},
|
||||||
|
nb::arg(),
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def conj(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Return the elementwise complex conjugate of the input.
|
||||||
|
Alias for `mx.conjugate`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"conjugate",
|
||||||
|
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||||
|
return mlx::core::conjugate(to_array(a), s);
|
||||||
|
},
|
||||||
|
nb::arg(),
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def conjugate(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Return the elementwise complex conjugate of the input.
|
||||||
|
Alias for `mx.conj`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"convolve",
|
"convolve",
|
||||||
[](const array& a,
|
[](const array& a,
|
||||||
|
@ -1245,6 +1245,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
"log1p",
|
"log1p",
|
||||||
"floor",
|
"floor",
|
||||||
"ceil",
|
"ceil",
|
||||||
|
"conjugate",
|
||||||
]
|
]
|
||||||
|
|
||||||
x = 0.5
|
x = 0.5
|
||||||
@ -2258,6 +2259,19 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
out_np = getattr(np, op)(a_np, b_np)
|
out_np = getattr(np, op)(a_np, b_np)
|
||||||
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
||||||
|
|
||||||
|
def test_conjugate(self):
|
||||||
|
shape = (3, 5, 7)
|
||||||
|
a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
|
||||||
|
a = a.astype(np.complex64)
|
||||||
|
ops = ["conjugate", "conj"]
|
||||||
|
for op in ops:
|
||||||
|
out_mlx = getattr(mx, op)(mx.array(a))
|
||||||
|
out_np = getattr(np, op)(a)
|
||||||
|
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
||||||
|
out_mlx = mx.array(a).conj()
|
||||||
|
out_np = a.conj()
|
||||||
|
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user