Add bitwise ops (#1037)

* bitwise ops

* fix tests
This commit is contained in:
Awni Hannun
2024-04-26 22:03:42 -07:00
committed by GitHub
parent 67d1894759
commit 86f495985b
17 changed files with 568 additions and 58 deletions

View File

@@ -1017,11 +1017,7 @@ void init_array(nb::module_& m) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise and.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise and not yet supported for integer types.");
}
return logical_and(a, b);
return bitwise_and(a, b);
},
"other"_a)
.def(
@@ -1036,11 +1032,7 @@ void init_array(nb::module_& m) {
throw std::invalid_argument(
"Floating point types not allowed with bitwise and.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise and not yet supported for integer types.");
}
a.overwrite_descriptor(logical_and(a, b));
a.overwrite_descriptor(bitwise_and(a, b));
return a;
},
"other"_a,
@@ -1057,11 +1049,7 @@ void init_array(nb::module_& m) {
throw std::invalid_argument(
"Floating point types not allowed with or bitwise or.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise or not yet supported for integer types.");
}
return logical_or(a, b);
return bitwise_or(a, b);
},
"other"_a)
.def(
@@ -1076,11 +1064,71 @@ void init_array(nb::module_& m) {
throw std::invalid_argument(
"Floating point types not allowed with or bitwise or.");
}
if (a.dtype() != bool_ && b.dtype() != bool_) {
throw std::invalid_argument(
"Bitwise or not yet supported for integer types.");
a.overwrite_descriptor(bitwise_or(a, b));
return a;
},
"other"_a,
nb::rv_policy::none)
.def(
"__lshift__",
[](const array& a, const ScalarOrArray v) {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("left shift", v);
}
a.overwrite_descriptor(logical_or(a, b));
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument(
"Floating point types not allowed with left shift.");
}
return left_shift(a, b);
},
"other"_a)
.def(
"__ilshift__",
[](array& a, const ScalarOrArray v) -> array& {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("inplace left shift", v);
}
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument(
"Floating point types not allowed with or left shift.");
}
a.overwrite_descriptor(left_shift(a, b));
return a;
},
"other"_a,
nb::rv_policy::none)
.def(
"__rshift__",
[](const array& a, const ScalarOrArray v) {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("right shift", v);
}
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument(
"Floating point types not allowed with right shift.");
}
return right_shift(a, b);
},
"other"_a)
.def(
"__irshift__",
[](array& a, const ScalarOrArray v) -> array& {
if (!is_comparable_with_array(v)) {
throw_invalid_operation("inplace right shift", v);
}
auto b = to_array(v, a.dtype());
if (issubdtype(a.dtype(), inexact) ||
issubdtype(b.dtype(), inexact)) {
throw std::invalid_argument(
"Floating point types not allowed with or right shift.");
}
a.overwrite_descriptor(right_shift(a, b));
return a;
},
"other"_a,

View File

@@ -3702,8 +3702,8 @@ void init_ops(nb::module_& m) {
* ``lhs_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `K` / ``block_size`` :math:`\rceil`)
* ``rhs_mask`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
* ``rhs_mask`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
* ``out_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
Note: Only ``block_size=64`` and ``block_size=32`` are currently supported
@@ -3897,4 +3897,132 @@ void init_ops(nb::module_& m) {
&issubdtype),
""_a,
""_a);
m.def(
"bitwise_and",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return bitwise_and(a, b, s);
},
nb::arg(),
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def bitwise_and(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise bitwise and.
Take the bitwise and of two arrays with numpy-style broadcasting
semantics. Either or both input arrays can also be scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The bitwise and ``a & b``.
)pbdoc");
m.def(
"bitwise_or",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return bitwise_or(a, b, s);
},
nb::arg(),
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def bitwise_or(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise bitwise or.
Take the bitwise or of two arrays with numpy-style broadcasting
semantics. Either or both input arrays can also be scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The bitwise or``a | b``.
)pbdoc");
m.def(
"bitwise_xor",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return bitwise_xor(a, b, s);
},
nb::arg(),
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def bitwise_xor(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise bitwise xor.
Take the bitwise exclusive or of two arrays with numpy-style
broadcasting semantics. Either or both input arrays can also be
scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The bitwise xor ``a ^ b``.
)pbdoc");
m.def(
"left_shift",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return left_shift(a, b, s);
},
nb::arg(),
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def left_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise left shift.
Shift the bits of the first input to the left by the second using
numpy-style broadcasting semantics. Either or both input arrays can
also be scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The bitwise left shift ``a << b``.
)pbdoc");
m.def(
"right_shift",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return right_shift(a, b, s);
},
nb::arg(),
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def right_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise right shift.
Shift the bits of the first input to the right by the second using
numpy-style broadcasting semantics. Either or both input arrays can
also be scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The bitwise right shift ``a >> b``.
)pbdoc");
}