mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
Fixed shift operations issue (#2080)
* Fixed shift operations issue * Added tests and fixes * Fixed loop syntax error * Added tests for bool * Fixed typo
This commit is contained in:
12
mlx/ops.cpp
12
mlx/ops.cpp
@@ -4915,8 +4915,10 @@ array operator^(const array& a, const array& b) {
|
||||
}
|
||||
|
||||
array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
// Bit shift on bool always up-casts to uint8
|
||||
auto t = promote_types(result_type(a, b), uint8);
|
||||
auto t = result_type(a, b);
|
||||
if (t == bool_) {
|
||||
t = uint8;
|
||||
}
|
||||
return bitwise_impl(
|
||||
astype(a, t, s),
|
||||
astype(b, t, s),
|
||||
@@ -4929,8 +4931,10 @@ array operator<<(const array& a, const array& b) {
|
||||
}
|
||||
|
||||
array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
// Bit shift on bool always up-casts to uint8
|
||||
auto t = promote_types(result_type(a, b), uint8);
|
||||
auto t = result_type(a, b);
|
||||
if (t == bool_) {
|
||||
t = uint8;
|
||||
}
|
||||
return bitwise_impl(
|
||||
astype(a, t, s),
|
||||
astype(b, t, s),
|
||||
|
Reference in New Issue
Block a user