Add bitwise ops (#1037)

* bitwise ops

* fix tests
This commit is contained in:
Awni Hannun
2024-04-26 22:03:42 -07:00
committed by GitHub
parent 67d1894759
commit 86f495985b
17 changed files with 568 additions and 58 deletions

View File

@@ -3702,8 +3702,8 @@ void init_ops(nb::module_& m) {
* ``lhs_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `K` / ``block_size`` :math:`\rceil`)
* ``rhs_mask`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
* ``rhs_mask`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
* ``out_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
Note: Only ``block_size=64`` and ``block_size=32`` are currently supported
@@ -3897,4 +3897,132 @@ void init_ops(nb::module_& m) {
&issubdtype),
""_a,
""_a);
m.def(
"bitwise_and",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return bitwise_and(a, b, s);
},
nb::arg(),
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def bitwise_and(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise bitwise and.
Take the bitwise and of two arrays with numpy-style broadcasting
semantics. Either or both input arrays can also be scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The bitwise and ``a & b``.
)pbdoc");
m.def(
"bitwise_or",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return bitwise_or(a, b, s);
},
nb::arg(),
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def bitwise_or(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise bitwise or.
Take the bitwise or of two arrays with numpy-style broadcasting
semantics. Either or both input arrays can also be scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The bitwise or``a | b``.
)pbdoc");
m.def(
"bitwise_xor",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return bitwise_xor(a, b, s);
},
nb::arg(),
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def bitwise_xor(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise bitwise xor.
Take the bitwise exclusive or of two arrays with numpy-style
broadcasting semantics. Either or both input arrays can also be
scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The bitwise xor ``a ^ b``.
)pbdoc");
m.def(
"left_shift",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return left_shift(a, b, s);
},
nb::arg(),
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def left_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise left shift.
Shift the bits of the first input to the left by the second using
numpy-style broadcasting semantics. Either or both input arrays can
also be scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The bitwise left shift ``a << b``.
)pbdoc");
m.def(
"right_shift",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
auto [a, b] = to_arrays(a_, b_);
return right_shift(a, b, s);
},
nb::arg(),
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def right_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Element-wise right shift.
Shift the bits of the first input to the right by the second using
numpy-style broadcasting semantics. Either or both input arrays can
also be scalars.
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
Returns:
array: The bitwise right shift ``a >> b``.
)pbdoc");
}