Add bitwise ops (#1037)

* bitwise ops

* fix tests
This commit is contained in:
Awni Hannun
2024-04-26 22:03:42 -07:00
committed by GitHub
parent 67d1894759
commit 86f495985b
17 changed files with 568 additions and 58 deletions

View File

@@ -1017,11 +1017,7 @@ void init_array(nb::module_& m) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise and.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise and not yet supported for integer types.");
}
return logical_and(a, b);
return bitwise_and(a, b);
},
"other"_a)
.def(
@@ -1036,11 +1032,7 @@ void init_array(nb::module_& m) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise and.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise and not yet supported for integer types.");
}
a.overwrite_descriptor(logical_and(a, b));
a.overwrite_descriptor(bitwise_and(a, b));
return a;
},
"other"_a,
@@ -1057,11 +1049,7 @@ void init_array(nb::module_& m) {
throw std::invalid_argument(
"Floating point types not allowed with or bitwise or.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise or not yet supported for integer types.");
}
return logical_or(a, b);
return bitwise_or(a, b);
},
"other"_a)
.def(
@@ -1076,11 +1064,71 @@ void init_array(nb::module_& m) {
throw std::invalid_argument(
"Floating point types not allowed with or bitwise or.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise or not yet supported for integer types.");
a.overwrite_descriptor(bitwise_or(a, b));
return a;
},
"other"_a,
nb::rv_policy::none)
.def(
"__lshift__",
[](const array& a, const ScalarOrArray v) {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("left shift", v);
}
a.overwrite_descriptor(logical_or(a, b));
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument(
"Floating point types not allowed with left shift.");
}
return left_shift(a, b);
},
"other"_a)
.def(
"__ilshift__",
[](array& a, const ScalarOrArray v) -> array& {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("inplace left shift", v);
}
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument(
"Floating point types not allowed with or left shift.");
}
a.overwrite_descriptor(left_shift(a, b));
return a;
},
"other"_a,
nb::rv_policy::none)
.def(
"__rshift__",
[](const array& a, const ScalarOrArray v) {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("right shift", v);
}
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument(
"Floating point types not allowed with right shift.");
}
return right_shift(a, b);
},
"other"_a)
.def(
"__irshift__",
[](array& a, const ScalarOrArray v) -> array& {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("inplace right shift", v);
}
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument(
"Floating point types not allowed with or right shift.");
}
a.overwrite_descriptor(right_shift(a, b));
return a;
},
"other"_a,