mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 01:19:21 +08:00
@@ -1017,11 +1017,7 @@ void init_array(nb::module_& m) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with bitwise and.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise and not yet supported for integer types.");
|
||||
}
|
||||
return logical_and(a, b);
|
||||
return bitwise_and(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
@@ -1036,11 +1032,7 @@ void init_array(nb::module_& m) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with bitwise and.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise and not yet supported for integer types.");
|
||||
}
|
||||
a.overwrite_descriptor(logical_and(a, b));
|
||||
a.overwrite_descriptor(bitwise_and(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a,
|
||||
@@ -1057,11 +1049,7 @@ void init_array(nb::module_& m) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise or not yet supported for integer types.");
|
||||
}
|
||||
return logical_or(a, b);
|
||||
return bitwise_or(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
@@ -1076,11 +1064,71 @@ void init_array(nb::module_& m) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or bitwise or.");
|
||||
}
|
||||
if (a.dtype() != bool_ && b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"Bitwise or not yet supported for integer types.");
|
||||
a.overwrite_descriptor(bitwise_or(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a,
|
||||
nb::rv_policy::none)
|
||||
.def(
|
||||
"__lshift__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("left shift", v);
|
||||
}
|
||||
a.overwrite_descriptor(logical_or(a, b));
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with left shift.");
|
||||
}
|
||||
return left_shift(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ilshift__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace left shift", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or left shift.");
|
||||
}
|
||||
a.overwrite_descriptor(left_shift(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a,
|
||||
nb::rv_policy::none)
|
||||
.def(
|
||||
"__rshift__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("right shift", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with right shift.");
|
||||
}
|
||||
return right_shift(a, b);
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__irshift__",
|
||||
[](array& a, const ScalarOrArray v) -> array& {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
throw_invalid_operation("inplace right shift", v);
|
||||
}
|
||||
auto b = to_array(v, a.dtype());
|
||||
if (issubdtype(a.dtype(), inexact) ||
|
||||
issubdtype(b.dtype(), inexact)) {
|
||||
throw std::invalid_argument(
|
||||
"Floating point types not allowed with or right shift.");
|
||||
}
|
||||
a.overwrite_descriptor(right_shift(a, b));
|
||||
return a;
|
||||
},
|
||||
"other"_a,
|
||||
|
||||
Reference in New Issue
Block a user