update right_shift and lef_shift

This commit is contained in:
Redmept1on 2025-04-25 11:29:25 +08:00
parent eaf709b83e
commit 803a417ed7

View File

@ -4924,10 +4924,25 @@ array operator^(const array& a, const array& b) {
} }
array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
if (!(issubdtype(a.dtype(), integer) || a.dtype() == bool_)) {
std::ostringstream msg;
msg << "[left_shift] First argument must be integer or boolean type "
<< "but got type " << a.dtype() << ".";
throw std::runtime_error(msg.str());
}
if (!(issubdtype(b.dtype(), integer) || b.dtype() == bool_)) {
std::ostringstream msg;
msg << "[left_shift] Second argument must be integer or boolean type "
<< "but got type " << b.dtype() << ".";
throw std::runtime_error(msg.str());
}
auto t = result_type(a, b); auto t = result_type(a, b);
if (t == bool_) { if (t == bool_) {
t = uint8; t = uint8;
} }
return bitwise_impl( return bitwise_impl(
astype(a, t, s), astype(a, t, s),
astype(b, t, s), astype(b, t, s),
@ -4940,10 +4955,25 @@ array operator<<(const array& a, const array& b) {
} }
array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
if (!(issubdtype(a.dtype(), integer) || a.dtype() == bool_)) {
std::ostringstream msg;
msg << "[right_shift] First argument must be integer or boolean type "
<< "but got type " << a.dtype() << ".";
throw std::runtime_error(msg.str());
}
if (!(issubdtype(b.dtype(), integer) || b.dtype() == bool_)) {
std::ostringstream msg;
msg << "[right_shift] Second argument must be integer or boolean type "
<< "but got type " << b.dtype() << ".";
throw std::runtime_error(msg.str());
}
auto t = result_type(a, b); auto t = result_type(a, b);
if (t == bool_) { if (t == bool_) {
t = uint8; t = uint8;
} }
return bitwise_impl( return bitwise_impl(
astype(a, t, s), astype(a, t, s),
astype(b, t, s), astype(b, t, s),