feat: add logicalAnd and logicalOR (#386)

* feat: add logicalAnd and logicalOR

* run pre-commit

* Refactor logical_and and logical_or functions

* Add acknowledgement

* Add logical AND and logical OR operators

* Refactor logical_and and logical_or functions

* Add support for logical operators on bool arrays

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Add logical AND and OR operators for arrays and scalars

* Refactor vjp and jvp methods in primitives.cpp

* Add overloaded operators for logical AND and OR

* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan
2024-01-08 19:00:05 +04:00
committed by GitHub
parent 022a944367
commit 73321b8097
17 changed files with 311 additions and 1 deletions

View File

@@ -764,6 +764,18 @@ void init_array(py::module_& m) {
return power(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__and__",
[](const array& a, const ScalarOrArray v) {
return logical_and(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__or__",
[](const array& a, const ScalarOrArray v) {
return logical_or(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"flatten",
[](const array& a,

View File

@@ -670,6 +670,51 @@ void init_ops(py::module_& m) {
Returns:
array: The boolean array containing the logical not of ``a``.
)pbdoc");
m.def(
"logical_and",
[](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) {
return logical_and(to_array(a), to_array(b), s);
},
"a"_a,
"b"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
logical_and(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise logical and.
Args:
a (array): First input array or scalar.
b (array): Second input array or scalar.
Returns:
array: The boolean array containing the logical and of ``a`` and ``b``.
)pbdoc");
m.def(
"logical_or",
[](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) {
return logical_or(to_array(a), to_array(b), s);
},
"a"_a,
"b"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
logical_or(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
Element-wise logical or.
Args:
a (array): First input array or scalar.
b (array): Second input array or scalar.
Returns:
array: The boolean array containing the logical or of ``a`` and ``b``.
)pbdoc");
m.def(
"logaddexp",
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {