diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c2aa4786f..61b1130b5 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4924,10 +4924,25 @@ array operator^(const array& a, const array& b) { } array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { + if (!(issubdtype(a.dtype(), integer) || a.dtype() == bool_)) { + std::ostringstream msg; + msg << "[left_shift] First argument must be integer or boolean type " + << "but got type " << a.dtype() << "."; + throw std::runtime_error(msg.str()); + } + + if (!(issubdtype(b.dtype(), integer) || b.dtype() == bool_)) { + std::ostringstream msg; + msg << "[left_shift] Second argument must be integer or boolean type " + << "but got type " << b.dtype() << "."; + throw std::runtime_error(msg.str()); + } + auto t = result_type(a, b); if (t == bool_) { t = uint8; } + return bitwise_impl( astype(a, t, s), astype(b, t, s), @@ -4940,10 +4955,25 @@ array operator<<(const array& a, const array& b) { } array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { + if (!(issubdtype(a.dtype(), integer) || a.dtype() == bool_)) { + std::ostringstream msg; + msg << "[right_shift] First argument must be integer or boolean type " + << "but got type " << a.dtype() << "."; + throw std::runtime_error(msg.str()); + } + + if (!(issubdtype(b.dtype(), integer) || b.dtype() == bool_)) { + std::ostringstream msg; + msg << "[right_shift] Second argument must be integer or boolean type " + << "but got type " << b.dtype() << "."; + throw std::runtime_error(msg.str()); + } + auto t = result_type(a, b); if (t == bool_) { t = uint8; } + return bitwise_impl( astype(a, t, s), astype(b, t, s),