feat: add logicalAnd and logicalOR (#386)

* feat: add logicalAnd and logicalOR

* run pre-commit

* Refactor logical_and and logical_or functions

* Add acknowledgement

* Add logical AND and logical OR operators

* Refactor logical_and and logical_or functions

* Add support for logical operators on bool arrays

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Add logical AND and OR operators for arrays and scalars

* Refactor vjp and jvp methods in primitives.cpp

* Add overloaded operators for logical AND and OR

* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan
2024-01-08 19:00:05 +04:00
committed by GitHub
parent 022a944367
commit 73321b8097
17 changed files with 311 additions and 1 deletions

View File

@@ -795,6 +795,44 @@ TEST_CASE("test arithmetic unary ops") {
CHECK_EQ(y.item<bool>(), true);
}
// Test logical and
{
array x(true);
array y(true);
CHECK_EQ(logical_and(x, y).item<bool>(), true);
x = array(1.0f);
y = array(1.0f);
auto z = logical_and(x, y);
CHECK_EQ(z.dtype(), bool_);
CHECK_EQ(z.item<bool>(), true);
x = array(0);
y = array(1.0f);
z = logical_and(x, y);
CHECK_EQ(z.dtype(), bool_);
CHECK_EQ(z.item<bool>(), false);
}
// Test logical or
{
array x(false);
array y(false);
CHECK_EQ(logical_or(x, y).item<bool>(), false);
x = array(1.0f);
y = array(1.0f);
auto z = logical_or(x, y);
CHECK_EQ(z.dtype(), bool_);
CHECK_EQ(z.item<bool>(), true);
x = array(0);
y = array(1.0f);
z = logical_or(x, y);
CHECK_EQ(z.dtype(), bool_);
CHECK_EQ(z.item<bool>(), true);
}
// Test abs
{
array x({-1.0f, 0.0f, 1.0f});