mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
@@ -1017,11 +1017,7 @@ void init_array(nb::module_& m) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with bitwise and.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise and not yet supported for integer types.");
|
||||
}
|
||||
return logical_and(a, b);
|
||||
return bitwise_and(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
@@ -1036,11 +1032,7 @@ void init_array(nb::module_& m) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with bitwise and.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise and not yet supported for integer types.");
|
||||
}
|
||||
a.overwrite_descriptor(logical_and(a, b));
|
||||
a.overwrite_descriptor(bitwise_and(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a,
|
||||
@@ -1057,11 +1049,7 @@ void init_array(nb::module_& m) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise or not yet supported for integer types.");
|
||||
}
|
||||
return logical_or(a, b);
|
||||
return bitwise_or(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
@@ -1076,11 +1064,71 @@ void init_array(nb::module_& m) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise or not yet supported for integer types.");
|
||||
a.overwrite_descriptor(bitwise_or(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a,
|
||||
nb::rv_policy::none)
|
||||
.def(
|
||||
"__lshift__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("left shift", v);
|
||||
}
|
||||
a.overwrite_descriptor(logical_or(a, b));
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with left shift.");
|
||||
}
|
||||
return left_shift(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ilshift__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace left shift", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or left shift.");
|
||||
}
|
||||
a.overwrite_descriptor(left_shift(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a,
|
||||
nb::rv_policy::none)
|
||||
.def(
|
||||
"__rshift__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("right shift", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with right shift.");
|
||||
}
|
||||
return right_shift(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__irshift__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace right shift", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or right shift.");
|
||||
}
|
||||
a.overwrite_descriptor(right_shift(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a,
|
||||
|
@@ -3702,8 +3702,8 @@ void init_ops(nb::module_& m) {
|
||||
|
||||
* ``lhs_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `K` / ``block_size`` :math:`\rceil`)
|
||||
|
||||
* ``rhs_mask`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
|
||||
|
||||
* ``rhs_mask`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
|
||||
|
||||
* ``out_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`)
|
||||
|
||||
Note: Only ``block_size=64`` and ``block_size=32`` are currently supported
|
||||
@@ -3897,4 +3897,132 @@ void init_ops(nb::module_& m) {
|
||||
&issubdtype),
|
||||
""_a,
|
||||
""_a);
|
||||
m.def(
|
||||
"bitwise_and",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
auto [a, b] = to_arrays(a_, b_);
|
||||
return bitwise_and(a, b, s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def bitwise_and(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Element-wise bitwise and.
|
||||
|
||||
Take the bitwise and of two arrays with numpy-style broadcasting
|
||||
semantics. Either or both input arrays can also be scalars.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The bitwise and ``a & b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"bitwise_or",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
auto [a, b] = to_arrays(a_, b_);
|
||||
return bitwise_or(a, b, s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def bitwise_or(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Element-wise bitwise or.
|
||||
|
||||
Take the bitwise or of two arrays with numpy-style broadcasting
|
||||
semantics. Either or both input arrays can also be scalars.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The bitwise or``a | b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"bitwise_xor",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
auto [a, b] = to_arrays(a_, b_);
|
||||
return bitwise_xor(a, b, s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def bitwise_xor(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Element-wise bitwise xor.
|
||||
|
||||
Take the bitwise exclusive or of two arrays with numpy-style
|
||||
broadcasting semantics. Either or both input arrays can also be
|
||||
scalars.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The bitwise xor ``a ^ b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"left_shift",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
auto [a, b] = to_arrays(a_, b_);
|
||||
return left_shift(a, b, s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def left_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Element-wise left shift.
|
||||
|
||||
Shift the bits of the first input to the left by the second using
|
||||
numpy-style broadcasting semantics. Either or both input arrays can
|
||||
also be scalars.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The bitwise left shift ``a << b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"right_shift",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
auto [a, b] = to_arrays(a_, b_);
|
||||
return right_shift(a, b, s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def right_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Element-wise right shift.
|
||||
|
||||
Shift the bits of the first input to the right by the second using
|
||||
numpy-style broadcasting semantics. Either or both input arrays can
|
||||
also be scalars.
|
||||
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The bitwise right shift ``a >> b``.
|
||||
)pbdoc");
|
||||
}
|
||||
|
Reference in New Issue
Block a user