mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Add conjugate operator (#1100)
* cpu and gpu impl * add mx.conj and array.conj() --------- Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
parent
8bd6bfa4b5
commit
2e158cf6d0
@ -38,6 +38,8 @@ Operations
|
||||
ceil
|
||||
clip
|
||||
concatenate
|
||||
conj
|
||||
conjugate
|
||||
convolve
|
||||
conv1d
|
||||
conv2d
|
||||
|
@ -36,6 +36,7 @@ DEFAULT(BlockSparseMM)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Conjugate)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
DEFAULT_MULTI(Depends)
|
||||
|
@ -47,6 +47,7 @@ DEFAULT(BlockSparseMM)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Conjugate)
|
||||
DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT(Cos)
|
||||
|
@ -209,6 +209,12 @@ struct Ceil {
|
||||
};
|
||||
};
|
||||
|
||||
struct Conjugate {
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return std::conj(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
|
@ -203,6 +203,17 @@ void Concatenate::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Conjugate::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == complex64) {
|
||||
unary_fp(in, out, detail::Conjugate());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[conjugate] conjugate must be called on complex input.");
|
||||
}
|
||||
}
|
||||
|
||||
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
|
@ -158,6 +158,12 @@ struct Cosh {
|
||||
};
|
||||
};
|
||||
|
||||
struct Conjugate {
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return complex64_t{x.real, -x.imag};
|
||||
}
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
|
@ -94,6 +94,7 @@ instantiate_unary_float(tanh, Tanh)
|
||||
instantiate_unary_float(round, Round)
|
||||
|
||||
instantiate_unary_all(abs, complex64, complex64_t, Abs)
|
||||
instantiate_unary_all(conj, complex64, complex64_t, Conjugate)
|
||||
instantiate_unary_all(cos, complex64, complex64_t, Cos)
|
||||
instantiate_unary_all(cosh, complex64, complex64_t, Cosh)
|
||||
instantiate_unary_all(exp, complex64, complex64_t, Exp)
|
||||
|
@ -588,6 +588,17 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == complex64) {
|
||||
unary_op(inputs, out, "conj");
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[conjugate] conjugate must be called on complex input.");
|
||||
}
|
||||
}
|
||||
|
||||
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
@ -39,6 +39,7 @@ NO_GPU(Broadcast)
|
||||
NO_GPU(Ceil)
|
||||
NO_GPU_MULTI(Compiled)
|
||||
NO_GPU(Concatenate)
|
||||
NO_GPU(Conjugate)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU(Copy)
|
||||
NO_GPU(Cos)
|
||||
|
@ -23,16 +23,17 @@ bool is_unary(const Primitive& p) {
|
||||
typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) ||
|
||||
typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) ||
|
||||
typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) ||
|
||||
typeid(p) == typeid(Cosh) || typeid(p) == typeid(Remainder) ||
|
||||
typeid(p) == typeid(Erf) || typeid(p) == typeid(ErfInv) ||
|
||||
typeid(p) == typeid(Exp) || typeid(p) == typeid(Floor) ||
|
||||
typeid(p) == typeid(Log) || typeid(p) == typeid(Log1p) ||
|
||||
typeid(p) == typeid(LogicalNot) || typeid(p) == typeid(Negative) ||
|
||||
typeid(p) == typeid(Round) || typeid(p) == typeid(Sigmoid) ||
|
||||
typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
|
||||
typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
|
||||
typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
|
||||
typeid(p) == typeid(Tanh) || typeid(p) == typeid(Expm1));
|
||||
typeid(p) == typeid(Conjugate) || typeid(p) == typeid(Cosh) ||
|
||||
typeid(p) == typeid(Remainder) || typeid(p) == typeid(Erf) ||
|
||||
typeid(p) == typeid(ErfInv) || typeid(p) == typeid(Exp) ||
|
||||
typeid(p) == typeid(Floor) || typeid(p) == typeid(Log) ||
|
||||
typeid(p) == typeid(Log1p) || typeid(p) == typeid(LogicalNot) ||
|
||||
typeid(p) == typeid(Negative) || typeid(p) == typeid(Round) ||
|
||||
typeid(p) == typeid(Sigmoid) || typeid(p) == typeid(Sign) ||
|
||||
typeid(p) == typeid(Sin) || typeid(p) == typeid(Sinh) ||
|
||||
typeid(p) == typeid(Square) || typeid(p) == typeid(Sqrt) ||
|
||||
typeid(p) == typeid(Tan) || typeid(p) == typeid(Tanh) ||
|
||||
typeid(p) == typeid(Expm1));
|
||||
}
|
||||
|
||||
bool is_binary(const Primitive& p) {
|
||||
|
@ -4101,6 +4101,15 @@ array number_of_elements(
|
||||
{a}));
|
||||
}
|
||||
|
||||
array conjugate(const array& a, StreamOrDevice s /* = {} */) {
|
||||
// Mirror NumPy's behaviour for real input
|
||||
if (a.dtype() != complex64) {
|
||||
return a;
|
||||
}
|
||||
return array(
|
||||
a.shape(), a.dtype(), std::make_shared<Conjugate>(to_stream(s)), {a});
|
||||
}
|
||||
|
||||
array bitwise_impl(
|
||||
const array& a,
|
||||
const array& b,
|
||||
|
@ -1239,6 +1239,8 @@ array number_of_elements(
|
||||
Dtype dtype = int32,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
array conjugate(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Bitwise and. */
|
||||
array bitwise_and(const array& a, const array& b, StreamOrDevice s = {});
|
||||
array operator&(const array& a, const array& b);
|
||||
|
@ -789,6 +789,14 @@ bool Concatenate::is_equivalent(const Primitive& other) const {
|
||||
return axis_ == c_other.axis_;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(axes.size() == 1);
|
||||
return {{conjugate(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
array conv_weight_backward_patches(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
|
@ -620,6 +620,22 @@ class Concatenate : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Conjugate : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_PRINT(Conjugate)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Convolution : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Convolution(
|
||||
|
@ -1584,5 +1584,13 @@ void init_array(nb::module_& m) {
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Extract a diagonal or construct a diagonal matrix.
|
||||
)pbdoc");
|
||||
)pbdoc")
|
||||
.def(
|
||||
"conj",
|
||||
[](const array& a, StreamOrDevice s) {
|
||||
return mlx::core::conjugate(to_array(a), s);
|
||||
},
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
"See :func:`conj`.");
|
||||
}
|
||||
|
@ -3027,6 +3027,40 @@ void init_ops(nb::module_& m) {
|
||||
inclusive (bool): The i-th element of the output includes the i-th
|
||||
element of the input.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"conj",
|
||||
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||
return mlx::core::conjugate(to_array(a), s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conj(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Return the elementwise complex conjugate of the input.
|
||||
Alias for `mx.conjugate`.
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"conjugate",
|
||||
[](const ScalarOrArray& a, StreamOrDevice s) {
|
||||
return mlx::core::conjugate(to_array(a), s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conjugate(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Return the elementwise complex conjugate of the input.
|
||||
Alias for `mx.conj`.
|
||||
|
||||
Args:
|
||||
a (array): Input array
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"convolve",
|
||||
[](const array& a,
|
||||
|
@ -1245,6 +1245,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
"log1p",
|
||||
"floor",
|
||||
"ceil",
|
||||
"conjugate",
|
||||
]
|
||||
|
||||
x = 0.5
|
||||
@ -2258,6 +2259,19 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out_np = getattr(np, op)(a_np, b_np)
|
||||
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
||||
|
||||
def test_conjugate(self):
|
||||
shape = (3, 5, 7)
|
||||
a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
|
||||
a = a.astype(np.complex64)
|
||||
ops = ["conjugate", "conj"]
|
||||
for op in ops:
|
||||
out_mlx = getattr(mx, op)(mx.array(a))
|
||||
out_np = getattr(np, op)(a)
|
||||
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
||||
out_mlx = mx.array(a).conj()
|
||||
out_np = a.conj()
|
||||
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user