Add bitwise ops (#1037)

* bitwise ops

* fix tests
This commit is contained in:
Awni Hannun 2024-04-26 22:03:42 -07:00 committed by GitHub
parent 67d1894759
commit 86f495985b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 568 additions and 58 deletions

View File

@ -28,8 +28,11 @@ Operations
atleast_1d
atleast_2d
atleast_3d
broadcast_to
bitwise_and
bitwise_or
bitwise_xor
block_masked_mm
broadcast_to
ceil
clip
concatenate
@ -69,6 +72,7 @@ Operations
isnan
isneginf
isposinf
left_shift
less
less_equal
linspace
@ -105,6 +109,7 @@ Operations
reciprocal
repeat
reshape
right_shift
round
rsqrt
save

View File

@ -236,4 +236,61 @@ void Subtract::eval(const std::vector<array>& inputs, array& out) {
binary(a, b, out, detail::Subtract());
}
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto dispatch_type = [&a, &b, &out](auto op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
default:
throw std::runtime_error(
"[BitwiseBinary::eval_cpu] Type not supported");
break;
}
};
switch (op_) {
case BitwiseBinary::And:
dispatch_type(detail::BitwiseAnd());
break;
case BitwiseBinary::Or:
dispatch_type(detail::BitwiseOr());
break;
case BitwiseBinary::Xor:
dispatch_type(detail::BitwiseXor());
break;
case BitwiseBinary::LeftShift:
dispatch_type(detail::LeftShift());
break;
case BitwiseBinary::RightShift:
dispatch_type(detail::RightShift());
break;
}
}
} // namespace mlx::core

View File

@ -606,4 +606,39 @@ struct Select {
}
};
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
};
};
} // namespace mlx::core::detail

View File

@ -229,3 +229,38 @@ struct LogicalOr {
return x || y;
};
};
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
};
};

View File

@ -184,13 +184,7 @@ template <typename T, typename U, typename Op>
instantiate_binary_g("g" #name #tname, itype, otype, op) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
#define instantiate_binary_integer(name, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
@ -199,6 +193,15 @@ template <typename T, typename U, typename Op>
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_integer(name, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op)
@ -241,3 +244,13 @@ instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
// Bitwise ops only need integer types and bool (except for l/r shift)
instantiate_binary_integer(bitwise_and, BitwiseAnd)
instantiate_binary_all(bitwise_and, bool_, bool, bool, BitwiseAnd)
instantiate_binary_integer(bitwise_or, BitwiseOr)
instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
instantiate_binary_integer(bitwise_xor, BitwiseXor)
instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
instantiate_binary_integer(left_shift, LeftShift)
instantiate_binary_integer(right_shift, RightShift)

View File

@ -533,6 +533,26 @@ void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (op_) {
case BitwiseBinary::And:
binary_op(inputs, out, "bitwise_and");
break;
case BitwiseBinary::Or:
binary_op(inputs, out, "bitwise_or");
break;
case BitwiseBinary::Xor:
binary_op(inputs, out, "bitwise_xor");
break;
case BitwiseBinary::LeftShift:
binary_op(inputs, out, "left_shift");
break;
case BitwiseBinary::RightShift:
binary_op(inputs, out, "right_shift");
break;
}
}
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}

View File

@ -31,6 +31,7 @@ NO_GPU(ArgReduce)
NO_GPU(ArgSort)
NO_GPU(AsType)
NO_GPU(AsStrided)
NO_GPU(BitwiseBinary)
NO_GPU(Broadcast)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)

View File

@ -45,7 +45,7 @@ bool is_binary(const Primitive& p) {
typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
typeid(p) == typeid(Subtract));
typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary));
}
bool is_ternary(const Primitive& p) {

View File

@ -3944,4 +3944,77 @@ array number_of_elements(
{a}));
}
array bitwise_impl(
const array& a,
const array& b,
BitwiseBinary::Op op,
const std::string& op_name,
const StreamOrDevice& s) {
auto out_type = promote_types(a.dtype(), b.dtype());
if (!(issubdtype(out_type, integer) || out_type == bool_)) {
std::ostringstream msg;
msg << "[" << op_name
<< "] Only allowed on integer or boolean types "
"but got types "
<< a.dtype() << " and " << b.dtype() << ".";
throw std::runtime_error(msg.str());
}
auto inputs =
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
return array(
a.shape(),
out_type,
std::make_shared<BitwiseBinary>(to_stream(s), op),
std::move(inputs));
}
array bitwise_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
return bitwise_impl(a, b, BitwiseBinary::Op::And, "bitwise_and", s);
}
array operator&(const array& a, const array& b) {
return bitwise_and(a, b);
}
array bitwise_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
return bitwise_impl(a, b, BitwiseBinary::Op::Or, "bitwise_or", s);
}
array operator|(const array& a, const array& b) {
return bitwise_or(a, b);
}
array bitwise_xor(const array& a, const array& b, StreamOrDevice s /* = {} */) {
return bitwise_impl(a, b, BitwiseBinary::Op::Xor, "bitwise_xor", s);
}
array operator^(const array& a, const array& b) {
return bitwise_xor(a, b);
}
array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Bit shift on bool always up-casts to uint8
auto t = promote_types(result_type(a, b), uint8);
return bitwise_impl(
astype(a, t, s),
astype(b, t, s),
BitwiseBinary::Op::LeftShift,
"left_shift",
s);
}
array operator<<(const array& a, const array& b) {
return left_shift(a, b);
}
array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Bit shift on bool always up-casts to uint8
auto t = promote_types(result_type(a, b), uint8);
return bitwise_impl(
astype(a, t, s),
astype(b, t, s),
BitwiseBinary::Op::RightShift,
"right_shift",
s);
}
array operator>>(const array& a, const array& b) {
return right_shift(a, b);
}
} // namespace mlx::core

View File

@ -1027,17 +1027,6 @@ softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
/** Raise elements of a to the power of b element-wise */
array power(const array& a, const array& b, StreamOrDevice s = {});
inline array operator^(const array& a, const array& b) {
return power(a, b);
}
template <typename T>
array operator^(T a, const array& b) {
return power(array(a), b);
}
template <typename T>
array operator^(const array& a, T b) {
return power(a, array(b));
}
/** Cumulative sum of an array. */
array cumsum(
@ -1239,6 +1228,26 @@ array number_of_elements(
Dtype dtype = int32,
StreamOrDevice s = {});
/** Bitwise and. */
array bitwise_and(const array& a, const array& b, StreamOrDevice s = {});
array operator&(const array& a, const array& b);
/** Bitwise inclusive or. */
array bitwise_or(const array& a, const array& b, StreamOrDevice s = {});
array operator|(const array& a, const array& b);
/** Bitwise exclusive or. */
array bitwise_xor(const array& a, const array& b, StreamOrDevice s = {});
array operator^(const array& a, const array& b);
/** Shift bits to the left. */
array left_shift(const array& a, const array& b, StreamOrDevice s = {});
array operator<<(const array& a, const array& b);
/** Shift bits to the right. */
array right_shift(const array& a, const array& b, StreamOrDevice s = {});
array operator>>(const array& a, const array& b);
/** @} */
} // namespace mlx::core

View File

@ -560,6 +560,44 @@ bool AsStrided::is_equivalent(const Primitive& other) const {
offset_ == a_other.offset_;
}
bool BitwiseBinary::is_equivalent(const Primitive& other) const {
const BitwiseBinary& a_other = static_cast<const BitwiseBinary&>(other);
return op_ == a_other.op_;
}
void BitwiseBinary::print(std::ostream& os) {
switch (op_) {
case BitwiseBinary::And:
os << "BitwiseAnd";
break;
case BitwiseBinary::Or:
os << "BitwiseOr";
break;
case BitwiseBinary::Xor:
os << "BitwiseXor";
break;
case BitwiseBinary::LeftShift:
os << "LeftShift";
break;
case BitwiseBinary::RightShift:
os << "RightShift";
break;
}
}
std::pair<std::vector<array>, std::vector<int>> BitwiseBinary::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
return {
{array(
a.shape(),
a.dtype(),
std::make_shared<BitwiseBinary>(stream(), op_),
{a, b})},
{to_ax}};
}
std::vector<array> Broadcast::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@ -443,6 +443,25 @@ class AsStrided : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class BitwiseBinary : public UnaryPrimitive {
public:
enum Op { And, Or, Xor, LeftShift, RightShift };
explicit BitwiseBinary(Stream stream, Op op)
: UnaryPrimitive(stream), op_(op){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
bool is_equivalent(const Primitive& other) const override;
void print(std::ostream& os) override;
DEFINE_INPUT_OUTPUT_SHAPE()
private:
Op op_;
};
class BlockMaskedMM : public UnaryPrimitive {
public:
explicit BlockMaskedMM(Stream stream, int block_size)

View File

@ -1017,11 +1017,7 @@ void init_array(nb::module_& m) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise and.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise and not yet supported for integer types.");
}
return logical_and(a, b);
return bitwise_and(a, b);
},
"other"_a)
.def(
@ -1036,11 +1032,7 @@ void init_array(nb::module_& m) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise and.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise and not yet supported for integer types.");
}
a.overwrite_descriptor(logical_and(a, b));
a.overwrite_descriptor(bitwise_and(a, b));
return a;
},
"other"_a,
@ -1057,11 +1049,7 @@ void init_array(nb::module_& m) {
throw std::invalid_argument(
"Floating point types not allowed with or bitwise or.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise or not yet supported for integer types.");
}
return logical_or(a, b);
return bitwise_or(a, b);
},
"other"_a)
.def(
@ -1076,11 +1064,71 @@ void init_array(nb::module_& m) {
throw std::invalid_argument(
"Floating point types not allowed with or bitwise or.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise or not yet supported for integer types.");
a.overwrite_descriptor(bitwise_or(a, b));
return a;
},
"other"_a,
nb::rv_policy::none)
.def(
"__lshift__",
[](const array& a, const ScalarOrArray v) {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("left shift", v);
}
a.overwrite_descriptor(logical_or(a, b));
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument(
"Floating point types not allowed with left shift.");
}
return left_shift(a, b);
},
"other"_a)
.def(
"__ilshift__",
[](array& a, const ScalarOrArray v) -> array& {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("inplace left shift", v);
}
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument(
"Floating point types not allowed with or left shift.");
}
a.overwrite_descriptor(left_shift(a, b));
return a;
},
"other"_a,
nb::rv_policy::none)
.def(
"__rshift__",
[](const array& a, const ScalarOrArray v) {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("right shift", v);
}
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument(
"Floating point types not allowed with right shift.");
}
return right_shift(a, b);
},
"other"_a)
.def(
"__irshift__",
[](array& a, const ScalarOrArray v) -> array& {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("inplace right shift", v);
}
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument(
"Floating point types not allowed with or right shift.");
}
a.overwrite_descriptor(right_shift(a, b));
return a;
},
"other"_a,

View File

@ -3897,4 +3897,132 @@ void init_ops(nb::module_& m) {
&issubdtype),
""_a,
""_a);
m.def(
"bitwise_and",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return bitwise_and(a, b, s);
},
nb::arg(),
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def bitwise_and(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise bitwise and.
Take the bitwise and of two arrays with numpy-style broadcasting
semantics. Either or both input arrays can also be scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The bitwise and ``a & b``.
)pbdoc");
m.def(
"bitwise_or",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return bitwise_or(a, b, s);
},
nb::arg(),
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def bitwise_or(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise bitwise or.
Take the bitwise or of two arrays with numpy-style broadcasting
semantics. Either or both input arrays can also be scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The bitwise or``a | b``.
)pbdoc");
m.def(
"bitwise_xor",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return bitwise_xor(a, b, s);
},
nb::arg(),
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def bitwise_xor(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise bitwise xor.
Take the bitwise exclusive or of two arrays with numpy-style
broadcasting semantics. Either or both input arrays can also be
scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The bitwise xor ``a ^ b``.
)pbdoc");
m.def(
"left_shift",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return left_shift(a, b, s);
},
nb::arg(),
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def left_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise left shift.
Shift the bits of the first input to the left by the second using
numpy-style broadcasting semantics. Either or both input arrays can
also be scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The bitwise left shift ``a << b``.
)pbdoc");
m.def(
"right_shift",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return right_shift(a, b, s);
},
nb::arg(),
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def right_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise right shift.
Shift the bits of the first input to the right by the second using
numpy-style broadcasting semantics. Either or both input arrays can
also be scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The bitwise right shift ``a >> b``.
)pbdoc");
}

View File

@ -2177,6 +2177,38 @@ class TestOps(mlx_tests.MLXTestCase):
f"mx and np don't aggree on {a}, {b}",
)
def test_bitwise_ops(self):
types = [
mx.uint8,
mx.uint16,
mx.uint32,
mx.uint64,
mx.int8,
mx.int16,
mx.int32,
mx.int64,
]
a = mx.random.randint(0, 4096, (1000,))
b = mx.random.randint(0, 4096, (1000,))
for op in ["bitwise_and", "bitwise_or", "bitwise_xor"]:
for t in types:
a_mlx = a.astype(t)
b_mlx = b.astype(t)
a_np = np.array(a_mlx)
b_np = np.array(b_mlx)
out_mlx = getattr(mx, op)(a_mlx, b_mlx)
out_np = getattr(np, op)(a_np, b_np)
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
for op in ["left_shift", "right_shift"]:
for t in types:
a_mlx = a.astype(t)
b_mlx = mx.random.randint(0, t.size, (1000,)).astype(t)
a_np = np.array(a_mlx)
b_np = np.array(b_mlx)
out_mlx = getattr(mx, op)(a_mlx, b_mlx)
out_np = getattr(np, op)(a_np, b_np)
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
if __name__ == "__main__":
unittest.main()

View File

@ -596,7 +596,7 @@ TEST_CASE("test op vjps") {
// Test power
{
auto fun = [](std::vector<array> inputs) {
return std::vector<array>{inputs[0] ^ inputs[1]};
return std::vector<array>{power(inputs[0], inputs[1])};
};
auto out = vjp(fun, {array(4.0f), array(3.0f)}, {array(1.0f)}).second;
CHECK_EQ(out[0].item<float>(), 48.0f);

View File

@ -2308,29 +2308,26 @@ TEST_CASE("test pad") {
TEST_CASE("test power") {
CHECK_EQ(power(array(1), array(2)).item<int>(), 1);
CHECK_EQ((array(1) ^ 2).item<int>(), 1);
CHECK_EQ((1 ^ array(2)).item<int>(), 1);
CHECK_EQ((array(-1) ^ 2).item<int>(), 1);
CHECK_EQ((array(-1) ^ 3).item<int>(), -1);
CHECK_EQ((power(array(-1), array(2))).item<int>(), 1);
CHECK_EQ((power(array(-1), array(3))).item<int>(), -1);
// TODO Throws but exception not caught from calling thread
// CHECK_THROWS((x^-1).item<int>());
CHECK_EQ((array(true) ^ array(false)).item<bool>(), true);
CHECK_EQ((array(false) ^ array(false)).item<bool>(), true);
CHECK_EQ((array(true) ^ array(true)).item<bool>(), true);
CHECK_EQ((array(false) ^ array(true)).item<bool>(), false);
CHECK_EQ((power(array(true), array(false))).item<bool>(), true);
CHECK_EQ((power(array(false), array(false))).item<bool>(), true);
CHECK_EQ((power(array(true), array(true))).item<bool>(), true);
CHECK_EQ((power(array(false), array(true))).item<bool>(), false);
auto x = array(2.0f);
CHECK_EQ((x ^ 0.5).item<float>(), doctest::Approx(std::pow(2.0f, 0.5f)));
CHECK_EQ((x ^ 2.0f).item<float>(), 4.0f);
CHECK_EQ(
(power(x, array(0.5))).item<float>(),
doctest::Approx(std::pow(2.0f, 0.5f)));
CHECK_EQ(power(x, array(2.0f)).item<float>(), 4.0f);
CHECK(std::isnan((array(-1.0f) ^ 0.5).item<float>()));
CHECK(std::isnan((power(array(-1.0f), array(0.5))).item<float>()));
auto a = complex64_t{0.5, 0.5};
auto b = complex64_t{0.5, 0.5};
auto expected = std::pow(a, b);
auto out = (array(a) ^ array(b)).item<complex64_t>();
auto out = (power(array(a), array(b))).item<complex64_t>();
CHECK(abs(out.real() - expected.real()) < 1e-7);
CHECK(abs(out.imag() - expected.imag()) < 1e-7);
}