diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index 7795512a0..7bae81069 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -28,8 +28,11 @@ Operations
    atleast_1d
    atleast_2d
    atleast_3d
-   broadcast_to
+   bitwise_and
+   bitwise_or
+   bitwise_xor
    block_masked_mm
+   broadcast_to
    ceil
    clip
    concatenate
@@ -69,6 +72,7 @@ Operations
    isnan
    isneginf
    isposinf
+   left_shift
    less
    less_equal
    linspace
@@ -105,6 +109,7 @@ Operations
    reciprocal
    repeat
    reshape
+   right_shift
    round
    rsqrt
    save
diff --git a/mlx/backend/common/binary.cpp b/mlx/backend/common/binary.cpp
index 810062dfd..060d7565c 100644
--- a/mlx/backend/common/binary.cpp
+++ b/mlx/backend/common/binary.cpp
@@ -236,4 +236,64 @@ void Subtract::eval(const std::vector<array>& inputs, array& out) {
   binary(a, b, out, detail::Subtract());
 }
 
+// Evaluate a bitwise binary op on the CPU. The element type is selected from
+// the output dtype and the operation from op_; dtypes other than bool and the
+// integer types throw std::runtime_error.
+void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto dispatch_type = [&a, &b, &out](auto op) {
+    switch (out.dtype()) {
+      case bool_:
+        binary_op<bool>(a, b, out, op);
+        break;
+      case uint8:
+        binary_op<uint8_t>(a, b, out, op);
+        break;
+      case uint16:
+        binary_op<uint16_t>(a, b, out, op);
+        break;
+      case uint32:
+        binary_op<uint32_t>(a, b, out, op);
+        break;
+      case uint64:
+        binary_op<uint64_t>(a, b, out, op);
+        break;
+      case int8:
+        binary_op<int8_t>(a, b, out, op);
+        break;
+      case int16:
+        binary_op<int16_t>(a, b, out, op);
+        break;
+      case int32:
+        binary_op<int32_t>(a, b, out, op);
+        break;
+      case int64:
+        binary_op<int64_t>(a, b, out, op);
+        break;
+      default:
+        throw std::runtime_error(
+            "[BitwiseBinary::eval_cpu] Type not supported");
+    }
+  };
+  switch (op_) {
+    case BitwiseBinary::And:
+      dispatch_type(detail::BitwiseAnd());
+      break;
+    case BitwiseBinary::Or:
+      dispatch_type(detail::BitwiseOr());
+      break;
+    case BitwiseBinary::Xor:
+      dispatch_type(detail::BitwiseXor());
+      break;
+    case BitwiseBinary::LeftShift:
+      dispatch_type(detail::LeftShift());
+      break;
+    case BitwiseBinary::RightShift:
+      dispatch_type(detail::RightShift());
+      break;
+  }
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/common/ops.h b/mlx/backend/common/ops.h
index 0aff1de37..ae2e2d225 100644
--- a/mlx/backend/common/ops.h
+++ b/mlx/backend/common/ops.h
@@ -606,4 +606,39 @@ struct Select {
   }
 };
 
+struct BitwiseAnd {
+  template <typename T>
+  T operator()(T x, T y) {
+    return x & y;
+  };
+};
+
+struct BitwiseOr {
+  template <typename T>
+  T operator()(T x, T y) {
+    return x | y;
+  };
+};
+
+struct BitwiseXor {
+  template <typename T>
+  T operator()(T x, T y) {
+    return x ^ y;
+  };
+};
+
+struct LeftShift {
+  template <typename T>
+  T operator()(T x, T y) {
+    return x << y;
+  };
+};
+
+struct RightShift {
+  template <typename T>
+  T operator()(T x, T y) {
+    return x >> y;
+  };
+};
+
 } // namespace mlx::core::detail
diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h
index 006f2ff0e..e9606d4f5 100644
--- a/mlx/backend/metal/kernels/binary.h
+++ b/mlx/backend/metal/kernels/binary.h
@@ -229,3 +229,38 @@ struct LogicalOr {
     return x || y;
   };
 };
+
+struct BitwiseAnd {
+  template <typename T>
+  T operator()(T x, T y) {
+    return x & y;
+  };
+};
+
+struct BitwiseOr {
+  template <typename T>
+  T operator()(T x, T y) {
+    return x | y;
+  };
+};
+
+struct BitwiseXor {
+  template <typename T>
+  T operator()(T x, T y) {
+    return x ^ y;
+  };
+};
+
+struct LeftShift {
+  template <typename T>
+  T operator()(T x, T y) {
+    return x << y;
+  };
+};
+
+struct RightShift {
+  template <typename T>
+  T operator()(T x, T y) {
+    return x >> y;
+  };
+};
diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal
index eff687231..7674a13f1 100644
--- a/mlx/backend/metal/kernels/binary.metal
+++ b/mlx/backend/metal/kernels/binary.metal
@@ -184,13 +184,7 @@ template <typename T, typename U, typename Op>
   instantiate_binary_g("g" #name #tname, itype, otype, op) \
   instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
 
-#define instantiate_binary_float(name, op) \
-  instantiate_binary_all(name, float16, half, half, op) \
-  instantiate_binary_all(name, float32, float, float, op) \
-  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
-
-#define instantiate_binary_types(name, op) \
-  instantiate_binary_all(name, bool_, bool, bool, op) \
+#define instantiate_binary_integer(name, op) \
   instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
   instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
   instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
@@ -199,6 +193,15 @@ template <typename T, typename U, typename Op>
   instantiate_binary_all(name, int16, int16_t, int16_t, op) \
   instantiate_binary_all(name, int32, int32_t, int32_t, op) \
   instantiate_binary_all(name, int64, int64_t, int64_t, op) \
+
+#define instantiate_binary_float(name, op) \
+  instantiate_binary_all(name, float16, half, half, op) \
+  instantiate_binary_all(name, float32, float, float, op) \
+  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
+
+#define instantiate_binary_types(name, op) \
+  instantiate_binary_all(name, bool_, bool, bool, op) \
+  instantiate_binary_integer(name, op) \
   instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
   instantiate_binary_float(name, op)
 
@@ -241,3 +244,13 @@ instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
 
 instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
 instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
+
+// Bitwise ops only need integer types and bool (except for l/r shift)
+instantiate_binary_integer(bitwise_and, BitwiseAnd)
+instantiate_binary_all(bitwise_and, bool_, bool, bool, BitwiseAnd)
+instantiate_binary_integer(bitwise_or, BitwiseOr)
+instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
+instantiate_binary_integer(bitwise_xor, BitwiseXor)
+instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
+instantiate_binary_integer(left_shift, LeftShift)
+instantiate_binary_integer(right_shift, RightShift)
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 9137ff12f..364132eba 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -533,6 +533,26 @@ void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }
 
+void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
+  switch (op_) {
+    case BitwiseBinary::And:
+      binary_op(inputs, out, "bitwise_and");
+      break;
+    case BitwiseBinary::Or:
+      binary_op(inputs, out, "bitwise_or");
+      break;
+    case BitwiseBinary::Xor:
+      binary_op(inputs, out, "bitwise_xor");
+      break;
+    case BitwiseBinary::LeftShift:
+      binary_op(inputs, out, "left_shift");
+      break;
+    case BitwiseBinary::RightShift:
+      binary_op(inputs, out, "right_shift");
+      break;
+  }
+}
+
 void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 4891415a3..f9248cf12 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -31,6 +31,7 @@ NO_GPU(ArgReduce)
 NO_GPU(ArgSort)
 NO_GPU(AsType)
 NO_GPU(AsStrided)
+NO_GPU(BitwiseBinary)
 NO_GPU(Broadcast)
 NO_GPU(Ceil)
 NO_GPU_MULTI(Compiled)
diff --git a/mlx/compile.cpp b/mlx/compile.cpp
index 456f658a1..55aafffaf 100644
--- a/mlx/compile.cpp
+++ b/mlx/compile.cpp
@@ -45,7 +45,7 @@ bool is_binary(const Primitive& p) {
       typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
       typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
       typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
-      typeid(p) == typeid(Subtract));
+      typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary));
 }
 
 bool is_ternary(const Primitive& p) {
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index b462547ec..1ec08fa56 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3944,4 +3944,83 @@ array number_of_elements(
       {a}));
 }
 
+// Shared implementation of the bitwise binary ops: promotes both inputs to a
+// common dtype, rejects anything that is not integer or bool, broadcasts the
+// inputs against each other, and builds the BitwiseBinary output array.
+array bitwise_impl(
+    const array& a,
+    const array& b,
+    BitwiseBinary::Op op,
+    const std::string& op_name,
+    const StreamOrDevice& s) {
+  auto out_type = promote_types(a.dtype(), b.dtype());
+  if (!(issubdtype(out_type, integer) || out_type == bool_)) {
+    std::ostringstream msg;
+    msg << "[" << op_name
+        << "] Only allowed on integer or boolean types "
+           "but got types "
+        << a.dtype() << " and " << b.dtype() << ".";
+    throw std::runtime_error(msg.str());
+  }
+  auto inputs =
+      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+  // The output has the broadcast shape, which differs from a.shape() whenever
+  // `a` is the smaller operand. Copy it before `inputs` is moved from below.
+  auto out_shape = inputs[0].shape();
+  return array(
+      std::move(out_shape),
+      out_type,
+      std::make_shared<BitwiseBinary>(to_stream(s), op),
+      std::move(inputs));
+}
+
+array bitwise_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  return bitwise_impl(a, b, BitwiseBinary::Op::And, "bitwise_and", s);
+}
+array operator&(const array& a, const array& b) {
+  return bitwise_and(a, b);
+}
+
+array bitwise_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  return bitwise_impl(a, b, BitwiseBinary::Op::Or, "bitwise_or", s);
+}
+array operator|(const array& a, const array& b) {
+  return bitwise_or(a, b);
+}
+
+array bitwise_xor(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  return bitwise_impl(a, b, BitwiseBinary::Op::Xor, "bitwise_xor", s);
+}
+array operator^(const array& a, const array& b) {
+  return bitwise_xor(a, b);
+}
+
+array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  // Bit shift on bool always up-casts to uint8
+  auto t = promote_types(result_type(a, b), uint8);
+  return bitwise_impl(
+      astype(a, t, s),
+      astype(b, t, s),
+      BitwiseBinary::Op::LeftShift,
+      "left_shift",
+      s);
+}
+array operator<<(const array& a, const array& b) {
+  return left_shift(a, b);
+}
+
+array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  // Bit shift on bool always up-casts to uint8
+  auto t = promote_types(result_type(a, b), uint8);
+  return bitwise_impl(
+      astype(a, t, s),
+      astype(b, t, s),
+      BitwiseBinary::Op::RightShift,
+      "right_shift",
+      s);
+}
+array operator>>(const array& a, const array& b) {
+  return right_shift(a, b);
+}
+
 } // namespace mlx::core
diff --git a/mlx/ops.h b/mlx/ops.h
index 15efee204..a909726f7 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1027,17 +1027,6 @@ softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
 
 /** Raise elements of a to the power of b element-wise */
 array power(const array& a, const array& b, StreamOrDevice s = {});
-inline array operator^(const array& a, const array& b) {
-  return power(a, b);
-}
-template <typename T>
-array operator^(T a, const array& b) {
-  return power(array(a), b);
-}
-template <typename T>
-array operator^(const array& a, T b) {
-  return power(a, array(b));
-}
 
 /** Cumulative sum of an array. */
 array cumsum(
@@ -1239,6 +1228,26 @@ array number_of_elements(
     Dtype dtype = int32,
     StreamOrDevice s = {});
 
+/** Bitwise and. */
+array bitwise_and(const array& a, const array& b, StreamOrDevice s = {});
+array operator&(const array& a, const array& b);
+
+/** Bitwise inclusive or. */
+array bitwise_or(const array& a, const array& b, StreamOrDevice s = {});
+array operator|(const array& a, const array& b);
+
+/** Bitwise exclusive or. */
+array bitwise_xor(const array& a, const array& b, StreamOrDevice s = {});
+array operator^(const array& a, const array& b);
+
+/** Shift bits to the left. */
+array left_shift(const array& a, const array& b, StreamOrDevice s = {});
+array operator<<(const array& a, const array& b);
+
+/** Shift bits to the right. */
+array right_shift(const array& a, const array& b, StreamOrDevice s = {});
+array operator>>(const array& a, const array& b);
+
 /** @} */
 
 } // namespace mlx::core
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 3cad50422..8543daa22 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -560,6 +560,44 @@ bool AsStrided::is_equivalent(const Primitive& other) const {
       offset_ == a_other.offset_;
 }
 
+bool BitwiseBinary::is_equivalent(const Primitive& other) const {
+  const BitwiseBinary& a_other = static_cast<const BitwiseBinary&>(other);
+  return op_ == a_other.op_;
+}
+
+void BitwiseBinary::print(std::ostream& os) {
+  switch (op_) {
+    case BitwiseBinary::And:
+      os << "BitwiseAnd";
+      break;
+    case BitwiseBinary::Or:
+      os << "BitwiseOr";
+      break;
+    case BitwiseBinary::Xor:
+      os << "BitwiseXor";
+      break;
+    case BitwiseBinary::LeftShift:
+      os << "LeftShift";
+      break;
+    case BitwiseBinary::RightShift:
+      os << "RightShift";
+      break;
+  }
+}
+
+std::pair<std::vector<array>, std::vector<int>> BitwiseBinary::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
+  return {
+      {array(
+          a.shape(),
+          a.dtype(),
+          std::make_shared<BitwiseBinary>(stream(), op_),
+          {a, b})},
+      {to_ax}};
+}
+
 std::vector<array> Broadcast::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index e90564724..390763b93 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -443,6 +443,25 @@ class AsStrided : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class BitwiseBinary : public UnaryPrimitive {
+ public:
+  enum Op { And, Or, Xor, LeftShift, RightShift };
+
+  explicit BitwiseBinary(Stream stream, Op op)
+      : UnaryPrimitive(stream), op_(op){};
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  bool is_equivalent(const Primitive& other) const override;
+  void print(std::ostream& os) override;
+  DEFINE_INPUT_OUTPUT_SHAPE()
+
+ private:
+  Op op_;
+};
+
 class BlockMaskedMM : public UnaryPrimitive {
  public:
   explicit BlockMaskedMM(Stream stream, int block_size)
diff --git a/python/src/array.cpp b/python/src/array.cpp
index b0f1e92b5..05f4f323c 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -1017,11 +1017,7 @@ void init_array(nb::module_& m) {
               throw std::invalid_argument(
                   "Floating point types not allowed with bitwise and.");
             }
-            if (a.dtype() != bool_ && b.dtype() != bool_) {
-              throw std::invalid_argument(
-                  "Bitwise and not yet supported for integer types.");
-            }
-            return logical_and(a, b);
+            return bitwise_and(a, b);
           },
           "other"_a)
       .def(
@@ -1036,11 +1032,7 @@ void init_array(nb::module_& m) {
               throw std::invalid_argument(
                   "Floating point types not allowed with bitwise and.");
             }
-            if (a.dtype() != bool_ && b.dtype() != bool_) {
-              throw std::invalid_argument(
-                  "Bitwise and not yet supported for integer types.");
-            }
-            a.overwrite_descriptor(logical_and(a, b));
+            a.overwrite_descriptor(bitwise_and(a, b));
             return a;
           },
           "other"_a,
@@ -1057,11 +1049,7 @@ void init_array(nb::module_& m) {
               throw std::invalid_argument(
                   "Floating point types not allowed with or bitwise or.");
             }
-            if (a.dtype() != bool_ && b.dtype() != bool_) {
-              throw std::invalid_argument(
-                  "Bitwise or not yet supported for integer types.");
-            }
-            return logical_or(a, b);
+            return bitwise_or(a, b);
           },
           "other"_a)
       .def(
@@ -1076,11 +1064,71 @@ void init_array(nb::module_& m) {
               throw std::invalid_argument(
                   "Floating point types not allowed with or bitwise or.");
             }
-            if (a.dtype() != bool_ && b.dtype() != bool_) {
-              throw std::invalid_argument(
-                  "Bitwise or not yet supported for integer types.");
+            a.overwrite_descriptor(bitwise_or(a, b));
+            return a;
+          },
+          "other"_a,
+          nb::rv_policy::none)
+      .def(
+          "__lshift__",
+          [](const array& a, const ScalarOrArray v) {
+            if (!is_comparable_with_array(v)) {
+              throw_invalid_operation("left shift", v);
             }
-            a.overwrite_descriptor(logical_or(a, b));
+            auto b = to_array(v, a.dtype());
+            if (issubdtype(a.dtype(), inexact) ||
+                issubdtype(b.dtype(), inexact)) {
+              throw std::invalid_argument(
+                  "Floating point types not allowed with left shift.");
+            }
+            return left_shift(a, b);
+          },
+          "other"_a)
+      .def(
+          "__ilshift__",
+          [](array& a, const ScalarOrArray v) -> array& {
+            if (!is_comparable_with_array(v)) {
+              throw_invalid_operation("inplace left shift", v);
+            }
+            auto b = to_array(v, a.dtype());
+            if (issubdtype(a.dtype(), inexact) ||
+                issubdtype(b.dtype(), inexact)) {
+              throw std::invalid_argument(
+                  "Floating point types not allowed with left shift.");
+            }
+            a.overwrite_descriptor(left_shift(a, b));
+            return a;
+          },
+          "other"_a,
+          nb::rv_policy::none)
+      .def(
+          "__rshift__",
+          [](const array& a, const ScalarOrArray v) {
+            if (!is_comparable_with_array(v)) {
+              throw_invalid_operation("right shift", v);
+            }
+            auto b = to_array(v, a.dtype());
+            if (issubdtype(a.dtype(), inexact) ||
+                issubdtype(b.dtype(), inexact)) {
+              throw std::invalid_argument(
+                  "Floating point types not allowed with right shift.");
+            }
+            return right_shift(a, b);
+          },
+          "other"_a)
+      .def(
+          "__irshift__",
+          [](array& a, const ScalarOrArray v) -> array& {
+            if (!is_comparable_with_array(v)) {
+              throw_invalid_operation("inplace right shift", v);
+            }
+            auto b = to_array(v, a.dtype());
+            if (issubdtype(a.dtype(), inexact) ||
+                issubdtype(b.dtype(), inexact)) {
+              throw std::invalid_argument(
+                  "Floating point types not allowed with right shift.");
+            }
+            a.overwrite_descriptor(right_shift(a, b));
             return a;
           },
           "other"_a,
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 22c0a97af..fb2bebdb6 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -3702,8 +3702,8 @@ void init_ops(nb::module_& m) {
 
         * ``lhs_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `K` / ``block_size`` :math:`\rceil`)
 
-        * ``rhs_mask`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
-
+        * ``rhs_mask`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
+
         * ``out_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
 
         Note: Only ``block_size=64`` and ``block_size=32`` are currently supported
@@ -3897,4 +3897,132 @@ void init_ops(nb::module_& m) {
           &issubdtype),
       ""_a,
       ""_a);
+  m.def(
+      "bitwise_and",
+      [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
+        auto [a, b] = to_arrays(a_, b_);
+        return bitwise_and(a, b, s);
+      },
+      nb::arg(),
+      nb::arg(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def bitwise_and(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Element-wise bitwise and.
+
+        Take the bitwise and of two arrays with numpy-style broadcasting
+        semantics. Either or both input arrays can also be scalars.
+
+        Args:
+            a (array): Input array or scalar.
+            b (array): Input array or scalar.
+
+        Returns:
+            array: The bitwise and ``a & b``.
+      )pbdoc");
+  m.def(
+      "bitwise_or",
+      [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
+        auto [a, b] = to_arrays(a_, b_);
+        return bitwise_or(a, b, s);
+      },
+      nb::arg(),
+      nb::arg(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def bitwise_or(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Element-wise bitwise or.
+
+        Take the bitwise or of two arrays with numpy-style broadcasting
+        semantics. Either or both input arrays can also be scalars.
+
+        Args:
+            a (array): Input array or scalar.
+            b (array): Input array or scalar.
+
+        Returns:
+            array: The bitwise or ``a | b``.
+      )pbdoc");
+  m.def(
+      "bitwise_xor",
+      [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
+        auto [a, b] = to_arrays(a_, b_);
+        return bitwise_xor(a, b, s);
+      },
+      nb::arg(),
+      nb::arg(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def bitwise_xor(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Element-wise bitwise xor.
+
+        Take the bitwise exclusive or of two arrays with numpy-style
+        broadcasting semantics. Either or both input arrays can also be
+        scalars.
+
+        Args:
+            a (array): Input array or scalar.
+            b (array): Input array or scalar.
+
+        Returns:
+            array: The bitwise xor ``a ^ b``.
+      )pbdoc");
+  m.def(
+      "left_shift",
+      [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
+        auto [a, b] = to_arrays(a_, b_);
+        return left_shift(a, b, s);
+      },
+      nb::arg(),
+      nb::arg(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def left_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Element-wise left shift.
+
+        Shift the bits of the first input to the left by the second using
+        numpy-style broadcasting semantics. Either or both input arrays can
+        also be scalars.
+
+        Args:
+            a (array): Input array or scalar.
+            b (array): Input array or scalar.
+
+        Returns:
+            array: The bitwise left shift ``a << b``.
+      )pbdoc");
+  m.def(
+      "right_shift",
+      [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
+        auto [a, b] = to_arrays(a_, b_);
+        return right_shift(a, b, s);
+      },
+      nb::arg(),
+      nb::arg(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def right_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Element-wise right shift.
+
+        Shift the bits of the first input to the right by the second using
+        numpy-style broadcasting semantics. Either or both input arrays can
+        also be scalars.
+
+        Args:
+            a (array): Input array or scalar.
+            b (array): Input array or scalar.
+
+        Returns:
+            array: The bitwise right shift ``a >> b``.
+      )pbdoc");
 }
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 696060ccf..af22bd14e 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -2177,6 +2177,38 @@ class TestOps(mlx_tests.MLXTestCase):
                 f"mx and np don't aggree on {a}, {b}",
             )
 
+    def test_bitwise_ops(self):
+        types = [
+            mx.uint8,
+            mx.uint16,
+            mx.uint32,
+            mx.uint64,
+            mx.int8,
+            mx.int16,
+            mx.int32,
+            mx.int64,
+        ]
+        a = mx.random.randint(0, 4096, (1000,))
+        b = mx.random.randint(0, 4096, (1000,))
+        for op in ["bitwise_and", "bitwise_or", "bitwise_xor"]:
+            for t in types:
+                a_mlx = a.astype(t)
+                b_mlx = b.astype(t)
+                a_np = np.array(a_mlx)
+                b_np = np.array(b_mlx)
+                out_mlx = getattr(mx, op)(a_mlx, b_mlx)
+                out_np = getattr(np, op)(a_np, b_np)
+                self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
+        for op in ["left_shift", "right_shift"]:
+            for t in types:
+                a_mlx = a.astype(t)
+                b_mlx = mx.random.randint(0, t.size, (1000,)).astype(t)
+                a_np = np.array(a_mlx)
+                b_np = np.array(b_mlx)
+                out_mlx = getattr(mx, op)(a_mlx, b_mlx)
+                out_np = getattr(np, op)(a_np, b_np)
+                self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp
index 0d01f5d10..9b763ec8b 100644
--- a/tests/autograd_tests.cpp
+++ b/tests/autograd_tests.cpp
@@ -596,7 +596,7 @@ TEST_CASE("test op vjps") {
   // Test power
   {
     auto fun = [](std::vector<array> inputs) {
-      return std::vector<array>{inputs[0] ^ inputs[1]};
+      return std::vector<array>{power(inputs[0], inputs[1])};
     };
     auto out = vjp(fun, {array(4.0f), array(3.0f)}, {array(1.0f)}).second;
     CHECK_EQ(out[0].item<float>(), 48.0f);
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 663fc1cf7..26abfcfb4 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -2308,29 +2308,26 @@ TEST_CASE("test pad") {
 
 TEST_CASE("test power") {
   CHECK_EQ(power(array(1), array(2)).item<int>(), 1);
-  CHECK_EQ((array(1) ^ 2).item<int>(), 1);
-  CHECK_EQ((1 ^ array(2)).item<int>(), 1);
-  CHECK_EQ((array(-1) ^ 2).item<int>(), 1);
-  CHECK_EQ((array(-1) ^ 3).item<int>(), -1);
+  CHECK_EQ((power(array(-1), array(2))).item<int>(), 1);
+  CHECK_EQ((power(array(-1), array(3))).item<int>(), -1);
 
-  // TODO Throws but exception not caught from calling thread
-  // CHECK_THROWS((x^-1).item<int>());
-
-  CHECK_EQ((array(true) ^ array(false)).item<bool>(), true);
-  CHECK_EQ((array(false) ^ array(false)).item<bool>(), true);
-  CHECK_EQ((array(true) ^ array(true)).item<bool>(), true);
-  CHECK_EQ((array(false) ^ array(true)).item<bool>(), false);
+  CHECK_EQ((power(array(true), array(false))).item<bool>(), true);
+  CHECK_EQ((power(array(false), array(false))).item<bool>(), true);
+  CHECK_EQ((power(array(true), array(true))).item<bool>(), true);
+  CHECK_EQ((power(array(false), array(true))).item<bool>(), false);
 
   auto x = array(2.0f);
-  CHECK_EQ((x ^ 0.5).item<float>(), doctest::Approx(std::pow(2.0f, 0.5f)));
-  CHECK_EQ((x ^ 2.0f).item<float>(), 4.0f);
+  CHECK_EQ(
+      (power(x, array(0.5))).item<float>(),
+      doctest::Approx(std::pow(2.0f, 0.5f)));
+  CHECK_EQ(power(x, array(2.0f)).item<float>(), 4.0f);
 
-  CHECK(std::isnan((array(-1.0f) ^ 0.5).item<float>()));
+  CHECK(std::isnan((power(array(-1.0f), array(0.5))).item<float>()));
 
   auto a = complex64_t{0.5, 0.5};
   auto b = complex64_t{0.5, 0.5};
   auto expected = std::pow(a, b);
-  auto out = (array(a) ^ array(b)).item<complex64_t>();
+  auto out = (power(array(a), array(b))).item<complex64_t>();
   CHECK(abs(out.real() - expected.real()) < 1e-7);
   CHECK(abs(out.imag() - expected.imag()) < 1e-7);
 }
@@ -3230,4 +3227,4 @@ TEST_CASE("test meshgrid") {
   expected_one = array({1, 2, 3}, {3, 1});
   CHECK(array_equal(out[0], expected_zero).item<bool>());
   CHECK(array_equal(out[1], expected_one).item<bool>());
-}
\ No newline at end of file
+}