From 5f04c0f818a05868b24e3f69db40205df4bcb75f Mon Sep 17 00:00:00 2001 From: Param Thakkar <128291516+ParamThakkar123@users.noreply.github.com> Date: Sat, 19 Apr 2025 02:58:33 +0530 Subject: [PATCH] Fixed shift operations issue (#2080) * Fixed shift operations issue * Added tests and fixes * Fixed loop syntax error * Added tests for bool * Fixed typo --- mlx/ops.cpp | 12 ++++++++---- tests/ops_tests.cpp | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 2f92088aa..54ac62fef 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4915,8 +4915,10 @@ array operator^(const array& a, const array& b) { } array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { - // Bit shift on bool always up-casts to uint8 - auto t = promote_types(result_type(a, b), uint8); + auto t = result_type(a, b); + if (t == bool_) { + t = uint8; + } return bitwise_impl( astype(a, t, s), astype(b, t, s), @@ -4929,8 +4931,10 @@ array operator<<(const array& a, const array& b) { } array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { - // Bit shift on bool always up-casts to uint8 - auto t = promote_types(result_type(a, b), uint8); + auto t = result_type(a, b); + if (t == bool_) { + t = uint8; + } return bitwise_impl( astype(a, t, s), astype(b, t, s), diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 356515702..de0f3352c 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -3874,3 +3874,41 @@ TEST_CASE("test contiguous") { CHECK(x.flags().col_contiguous); CHECK_EQ(x.strides(), decltype(x.strides()){1, 2}); } + +TEST_CASE("test bitwise shift operations") { + std::vector dtypes = { + int8, int16, int32, int64, uint8, uint16, uint32, uint64}; + + for (const auto& dtype : dtypes) { + array x = full({4}, 1, dtype); + array y = full({4}, 2, dtype); + + auto left_shift_result = left_shift(x, y); + CHECK_EQ(left_shift_result.dtype(), dtype); + CHECK(array_equal(left_shift_result, array({4, 4, 4, 4}, dtype)) + .item()); + + auto right_shift_result = right_shift(full({4}, 4, dtype), y); + CHECK_EQ(right_shift_result.dtype(), dtype); + CHECK(array_equal(right_shift_result, full({4}, 1, dtype)).item()); + } + + array x = array({127, -128}, int8); + array y = array({1, 1}, int8); + auto left_shift_result = left_shift(x, y); + auto right_shift_result = right_shift(x, y); + + CHECK(array_equal(left_shift_result, array({-2, 0}, int8)).item()); + CHECK(array_equal(right_shift_result, array({63, -64}, int8)).item()); + + array x_bool = full({4}, true, bool_); + array y_bool = full({4}, true, bool_); + auto left_shift_bool_result = left_shift(x_bool, y_bool); + auto right_shift_bool_result = right_shift(x_bool, y_bool); + + CHECK_EQ(left_shift_bool_result.dtype(), uint8); + CHECK(array_equal(left_shift_bool_result, full({4}, 2, uint8)).item()); + + CHECK_EQ(right_shift_bool_result.dtype(), uint8); + CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item()); +} \ No newline at end of file