feat: add logicalAnd and logicalOR (#386)

* feat: add logicalAnd and logicalOR

* run pre-commit

* Refactor logical_and and logical_or functions

* Add acknowledgement

* Add logical AND and logical OR operators

* Refactor logical_and and logical_or functions

* Add support for logical operators on bool arrays

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Add logical AND and OR operators for arrays and scalars

* Refactor vjp and jvp methods in primitives.cpp

* Add overloaded operators for logical AND and OR

* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan
2024-01-08 19:00:05 +04:00
committed by GitHub
parent 022a944367
commit 73321b8097
17 changed files with 311 additions and 1 deletions

View File

@@ -610,6 +610,28 @@ class TestOps(mlx_tests.MLXTestCase):
expected = np.logical_not(a)
self.assertTrue(np.array_equal(result, expected))
def test_logical_and(self):
a = mx.array([True, False, True, False])
b = mx.array([True, True, False, False])
result = mx.logical_and(a, b)
expected = np.logical_and(a, b)
self.assertTrue(np.array_equal(result, expected))
# test overloaded operator
result = a & b
self.assertTrue(np.array_equal(result, expected))
def test_logical_or(self):
a = mx.array([True, False, True, False])
b = mx.array([True, True, False, False])
result = mx.logical_or(a, b)
expected = np.logical_or(a, b)
self.assertTrue(np.array_equal(result, expected))
# test overloaded operator
result = a | b
self.assertTrue(np.array_equal(result, expected))
def test_square(self):
a = mx.array([0.1, 0.5, 1.0, 10.0])
result = mx.square(a)

View File

@@ -80,6 +80,8 @@ class TestVmap(mlx_tests.MLXTestCase):
"multiply",
"power",
"subtract",
"logical_or",
"logical_and",
]
for opname in ops:
with self.subTest(op=opname):