From 803a417ed7a7c3ac90de766c9b076863c4a554b3 Mon Sep 17 00:00:00 2001 From: Redmept1on <1090891928@qq.com> Date: Fri, 25 Apr 2025 11:29:25 +0800 Subject: [PATCH] update right_shift and lef_shift --- mlx/ops.cpp | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c2aa4786f..61b1130b5 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4924,10 +4924,25 @@ array operator^(const array& a, const array& b) { } array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { + if (!(issubdtype(a.dtype(), integer) || a.dtype() == bool_)) { + std::ostringstream msg; + msg << "[left_shift] First argument must be integer or boolean type " + << "but got type " << a.dtype() << "."; + throw std::runtime_error(msg.str()); + } + + if (!(issubdtype(b.dtype(), integer) || b.dtype() == bool_)) { + std::ostringstream msg; + msg << "[left_shift] Second argument must be integer or boolean type " + << "but got type " << b.dtype() << "."; + throw std::runtime_error(msg.str()); + } + auto t = result_type(a, b); if (t == bool_) { t = uint8; } + return bitwise_impl( astype(a, t, s), astype(b, t, s), @@ -4940,10 +4955,25 @@ array operator<<(const array& a, const array& b) { } array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { + if (!(issubdtype(a.dtype(), integer) || a.dtype() == bool_)) { + std::ostringstream msg; + msg << "[right_shift] First argument must be integer or boolean type " + << "but got type " << a.dtype() << "."; + throw std::runtime_error(msg.str()); + } + + if (!(issubdtype(b.dtype(), integer) || b.dtype() == bool_)) { + std::ostringstream msg; + msg << "[right_shift] Second argument must be integer or boolean type " + << "but got type " << b.dtype() << "."; + throw std::runtime_error(msg.str()); + } + auto t = result_type(a, b); if (t == bool_) { t = uint8; } + return bitwise_impl( astype(a, t, s), astype(b, t, s),