Fixed shift operations issue (#2080)

* Fixed shift operations issue

* Added tests and fixes

* Fixed loop syntax error

* Added tests for bool

* Fixed typo
This commit is contained in:
Param Thakkar 2025-04-19 02:58:33 +05:30 committed by GitHub
parent 55935ccae7
commit 5f04c0f818
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 4 deletions

View File

@ -4915,8 +4915,10 @@ array operator^(const array& a, const array& b) {
} }
array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { array left_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Bit shift on bool always up-casts to uint8 auto t = result_type(a, b);
auto t = promote_types(result_type(a, b), uint8); if (t == bool_) {
t = uint8;
}
return bitwise_impl( return bitwise_impl(
astype(a, t, s), astype(a, t, s),
astype(b, t, s), astype(b, t, s),
@ -4929,8 +4931,10 @@ array operator<<(const array& a, const array& b) {
} }
array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) { array right_shift(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Bit shift on bool always up-casts to uint8 auto t = result_type(a, b);
auto t = promote_types(result_type(a, b), uint8); if (t == bool_) {
t = uint8;
}
return bitwise_impl( return bitwise_impl(
astype(a, t, s), astype(a, t, s),
astype(b, t, s), astype(b, t, s),

View File

@ -3874,3 +3874,41 @@ TEST_CASE("test contiguous") {
CHECK(x.flags().col_contiguous); CHECK(x.flags().col_contiguous);
CHECK_EQ(x.strides(), decltype(x.strides()){1, 2}); CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
} }
TEST_CASE("test bitwise shift operations") {
std::vector<Dtype> dtypes = {
int8, int16, int32, int64, uint8, uint16, uint32, uint64};
for (const auto& dtype : dtypes) {
array x = full({4}, 1, dtype);
array y = full({4}, 2, dtype);
auto left_shift_result = left_shift(x, y);
CHECK_EQ(left_shift_result.dtype(), dtype);
CHECK(array_equal(left_shift_result, array({4, 4, 4, 4}, dtype))
.item<bool>());
auto right_shift_result = right_shift(full({4}, 4, dtype), y);
CHECK_EQ(right_shift_result.dtype(), dtype);
CHECK(array_equal(right_shift_result, full({4}, 1, dtype)).item<bool>());
}
array x = array({127, -128}, int8);
array y = array({1, 1}, int8);
auto left_shift_result = left_shift(x, y);
auto right_shift_result = right_shift(x, y);
CHECK(array_equal(left_shift_result, array({-2, 0}, int8)).item<bool>());
CHECK(array_equal(right_shift_result, array({63, -64}, int8)).item<bool>());
array x_bool = full({4}, true, bool_);
array y_bool = full({4}, true, bool_);
auto left_shift_bool_result = left_shift(x_bool, y_bool);
auto right_shift_bool_result = right_shift(x_bool, y_bool);
CHECK_EQ(left_shift_bool_result.dtype(), uint8);
CHECK(array_equal(left_shift_bool_result, full({4}, 2, uint8)).item<bool>());
CHECK_EQ(right_shift_bool_result.dtype(), uint8);
CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
}