This commit is contained in:
Awni Hannun 2025-04-25 06:30:52 -07:00
parent 803a417ed7
commit 0f18a974a3

View File

@ -4882,8 +4882,9 @@ array bitwise_impl(
const array& b, const array& b,
BitwiseBinary::Op op, BitwiseBinary::Op op,
const std::string& op_name, const std::string& op_name,
const StreamOrDevice& s) { const StreamOrDevice& s,
auto out_type = promote_types(a.dtype(), b.dtype()); std::optional<Dtype> out_type_ = std::nullopt) {
auto out_type = out_type_ ? *out_type_ : promote_types(a.dtype(), b.dtype());
if (!(issubdtype(out_type, integer) || out_type == bool_)) { if (!(issubdtype(out_type, integer) || out_type == bool_)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[" << op_name msg << "[" << op_name
@ -4924,62 +4925,28 @@ array operator^(const array& a, const array& b) {
} }
array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
if (!(issubdtype(a.dtype(), integer) || a.dtype() == bool_)) {
std::ostringstream msg;
msg << "[left_shift] First argument must be integer or boolean type "
<< "but got type " << a.dtype() << ".";
throw std::runtime_error(msg.str());
}
if (!(issubdtype(b.dtype(), integer) || b.dtype() == bool_)) {
std::ostringstream msg;
msg << "[left_shift] Second argument must be integer or boolean type "
<< "but got type " << b.dtype() << ".";
throw std::runtime_error(msg.str());
}
auto t = result_type(a, b); auto t = result_type(a, b);
if (t == bool_) { if (t == bool_) {
t = uint8; t = uint8;
} }
return bitwise_impl(a, b, BitwiseBinary::Op::LeftShift, "left_shift", s, t);
return bitwise_impl(
astype(a, t, s),
astype(b, t, s),
BitwiseBinary::Op::LeftShift,
"left_shift",
s);
} }
array operator<<(const array& a, const array& b) { array operator<<(const array& a, const array& b) {
return left_shift(a, b); return left_shift(a, b);
} }
array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
if (!(issubdtype(a.dtype(), integer) || a.dtype() == bool_)) {
std::ostringstream msg;
msg << "[right_shift] First argument must be integer or boolean type "
<< "but got type " << a.dtype() << ".";
throw std::runtime_error(msg.str());
}
if (!(issubdtype(b.dtype(), integer) || b.dtype() == bool_)) {
std::ostringstream msg;
msg << "[right_shift] Second argument must be integer or boolean type "
<< "but got type " << b.dtype() << ".";
throw std::runtime_error(msg.str());
}
auto t = result_type(a, b); auto t = result_type(a, b);
if (t == bool_) { if (t == bool_) {
t = uint8; t = uint8;
} }
return bitwise_impl( return bitwise_impl(
astype(a, t, s), astype(a, t, s),
astype(b, t, s), astype(b, t, s),
BitwiseBinary::Op::RightShift, BitwiseBinary::Op::RightShift,
"right_shift", "right_shift",
s); s,
t);
} }
array operator>>(const array& a, const array& b) { array operator>>(const array& a, const array& b) {
return right_shift(a, b); return right_shift(a, b);