mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 10:51:21 +08:00
simplify
This commit is contained in:
parent
803a417ed7
commit
0f18a974a3
45
mlx/ops.cpp
45
mlx/ops.cpp
@ -4882,8 +4882,9 @@ array bitwise_impl(
|
||||
const array& b,
|
||||
BitwiseBinary::Op op,
|
||||
const std::string& op_name,
|
||||
const StreamOrDevice& s) {
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
const StreamOrDevice& s,
|
||||
std::optional<Dtype> out_type_ = std::nullopt) {
|
||||
auto out_type = out_type_ ? *out_type_ : promote_types(a.dtype(), b.dtype());
|
||||
if (!(issubdtype(out_type, integer) || out_type == bool_)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << op_name
|
||||
@ -4924,62 +4925,28 @@ array operator^(const array& a, const array& b) {
|
||||
}
|
||||
|
||||
array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
if (!(issubdtype(a.dtype(), integer) || a.dtype() == bool_)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[left_shift] First argument must be integer or boolean type "
|
||||
<< "but got type " << a.dtype() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
if (!(issubdtype(b.dtype(), integer) || b.dtype() == bool_)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[left_shift] Second argument must be integer or boolean type "
|
||||
<< "but got type " << b.dtype() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
auto t = result_type(a, b);
|
||||
if (t == bool_) {
|
||||
t = uint8;
|
||||
}
|
||||
|
||||
return bitwise_impl(
|
||||
astype(a, t, s),
|
||||
astype(b, t, s),
|
||||
BitwiseBinary::Op::LeftShift,
|
||||
"left_shift",
|
||||
s);
|
||||
return bitwise_impl(a, b, BitwiseBinary::Op::LeftShift, "left_shift", s, t);
|
||||
}
|
||||
array operator<<(const array& a, const array& b) {
|
||||
return left_shift(a, b);
|
||||
}
|
||||
|
||||
array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
if (!(issubdtype(a.dtype(), integer) || a.dtype() == bool_)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[right_shift] First argument must be integer or boolean type "
|
||||
<< "but got type " << a.dtype() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
if (!(issubdtype(b.dtype(), integer) || b.dtype() == bool_)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[right_shift] Second argument must be integer or boolean type "
|
||||
<< "but got type " << b.dtype() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
auto t = result_type(a, b);
|
||||
if (t == bool_) {
|
||||
t = uint8;
|
||||
}
|
||||
|
||||
return bitwise_impl(
|
||||
astype(a, t, s),
|
||||
astype(b, t, s),
|
||||
BitwiseBinary::Op::RightShift,
|
||||
"right_shift",
|
||||
s);
|
||||
s,
|
||||
t);
|
||||
}
|
||||
array operator>>(const array& a, const array& b) {
|
||||
return right_shift(a, b);
|
||||
|
Loading…
Reference in New Issue
Block a user