mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-05 00:31:15 +08:00
update right_shift and lef_shift
This commit is contained in:
parent
eaf709b83e
commit
803a417ed7
30
mlx/ops.cpp
30
mlx/ops.cpp
@ -4924,10 +4924,25 @@ array operator^(const array& a, const array& b) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
|
if (!(issubdtype(a.dtype(), integer) || a.dtype() == bool_)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[left_shift] First argument must be integer or boolean type "
|
||||||
|
<< "but got type " << a.dtype() << ".";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!(issubdtype(b.dtype(), integer) || b.dtype() == bool_)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[left_shift] Second argument must be integer or boolean type "
|
||||||
|
<< "but got type " << b.dtype() << ".";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
auto t = result_type(a, b);
|
auto t = result_type(a, b);
|
||||||
if (t == bool_) {
|
if (t == bool_) {
|
||||||
t = uint8;
|
t = uint8;
|
||||||
}
|
}
|
||||||
|
|
||||||
return bitwise_impl(
|
return bitwise_impl(
|
||||||
astype(a, t, s),
|
astype(a, t, s),
|
||||||
astype(b, t, s),
|
astype(b, t, s),
|
||||||
@ -4940,10 +4955,25 @@ array operator<<(const array& a, const array& b) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
|
if (!(issubdtype(a.dtype(), integer) || a.dtype() == bool_)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[right_shift] First argument must be integer or boolean type "
|
||||||
|
<< "but got type " << a.dtype() << ".";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!(issubdtype(b.dtype(), integer) || b.dtype() == bool_)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[right_shift] Second argument must be integer or boolean type "
|
||||||
|
<< "but got type " << b.dtype() << ".";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
auto t = result_type(a, b);
|
auto t = result_type(a, b);
|
||||||
if (t == bool_) {
|
if (t == bool_) {
|
||||||
t = uint8;
|
t = uint8;
|
||||||
}
|
}
|
||||||
|
|
||||||
return bitwise_impl(
|
return bitwise_impl(
|
||||||
astype(a, t, s),
|
astype(a, t, s),
|
||||||
astype(b, t, s),
|
astype(b, t, s),
|
||||||
|
Loading…
Reference in New Issue
Block a user