diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 61b1130b5..f8308c2d5 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4882,8 +4882,9 @@ array bitwise_impl( const array& b, BitwiseBinary::Op op, const std::string& op_name, - const StreamOrDevice& s) { - auto out_type = promote_types(a.dtype(), b.dtype()); + const StreamOrDevice& s, + std::optional out_type_ = std::nullopt) { + auto out_type = out_type_ ? *out_type_ : promote_types(a.dtype(), b.dtype()); if (!(issubdtype(out_type, integer) || out_type == bool_)) { std::ostringstream msg; msg << "[" << op_name @@ -4924,62 +4925,28 @@ array operator^(const array& a, const array& b) { } array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { - if (!(issubdtype(a.dtype(), integer) || a.dtype() == bool_)) { - std::ostringstream msg; - msg << "[left_shift] First argument must be integer or boolean type " - << "but got type " << a.dtype() << "."; - throw std::runtime_error(msg.str()); - } - - if (!(issubdtype(b.dtype(), integer) || b.dtype() == bool_)) { - std::ostringstream msg; - msg << "[left_shift] Second argument must be integer or boolean type " - << "but got type " << b.dtype() << "."; - throw std::runtime_error(msg.str()); - } - auto t = result_type(a, b); if (t == bool_) { t = uint8; } - - return bitwise_impl( - astype(a, t, s), - astype(b, t, s), - BitwiseBinary::Op::LeftShift, - "left_shift", - s); + return bitwise_impl(a, b, BitwiseBinary::Op::LeftShift, "left_shift", s, t); } array operator<<(const array& a, const array& b) { return left_shift(a, b); } array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { - if (!(issubdtype(a.dtype(), integer) || a.dtype() == bool_)) { - std::ostringstream msg; - msg << "[right_shift] First argument must be integer or boolean type " - << "but got type " << a.dtype() << "."; - throw std::runtime_error(msg.str()); - } - - if (!(issubdtype(b.dtype(), integer) || b.dtype() == bool_)) { - std::ostringstream msg; - msg << "[right_shift] Second argument must be integer or boolean type " - << "but got type " << b.dtype() << "."; - throw std::runtime_error(msg.str()); - } - auto t = result_type(a, b); if (t == bool_) { t = uint8; } - return bitwise_impl( astype(a, t, s), astype(b, t, s), BitwiseBinary::Op::RightShift, "right_shift", - s); + s, + t); } array operator>>(const array& a, const array& b) { return right_shift(a, b);