Fix the error message in mx.right_shift and mx.left_shift (#2121)

* update right_shift and lef_shift

* simplify

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
1ndig0 2025-04-26 00:14:28 +08:00 committed by GitHub
parent eaf709b83e
commit 6b2d5448f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4882,8 +4882,9 @@ array bitwise_impl(
const array& b,
BitwiseBinary::Op op,
const std::string& op_name,
const StreamOrDevice& s) {
auto out_type = promote_types(a.dtype(), b.dtype());
const StreamOrDevice& s,
std::optional<Dtype> out_type_ = std::nullopt) {
auto out_type = out_type_ ? *out_type_ : promote_types(a.dtype(), b.dtype());
if (!(issubdtype(out_type, integer) || out_type == bool_)) {
std::ostringstream msg;
msg << "[" << op_name
@ -4928,12 +4929,7 @@ array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
if (t == bool_) {
t = uint8;
}
return bitwise_impl(
astype(a, t, s),
astype(b, t, s),
BitwiseBinary::Op::LeftShift,
"left_shift",
s);
return bitwise_impl(a, b, BitwiseBinary::Op::LeftShift, "left_shift", s, t);
}
array operator<<(const array& a, const array& b) {
return left_shift(a, b);
@ -4949,7 +4945,8 @@ array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
astype(b, t, s),
BitwiseBinary::Op::RightShift,
"right_shift",
s);
s,
t);
}
array operator>>(const array& a, const array& b) {
return right_shift(a, b);