mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-03 14:24:44 +08:00
feat: add logicalAnd and logicalOR (#386)
* feat: add logicalAnd and logicalOR * run pre-commit * Refactor logical_and and logical_or functions * Add acknowledgement * Add logical AND and logical OR operators * Refactor logical_and and logical_or functions * Add support for logical operators on bool arrays * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Add logical AND and OR operators for arrays and scalars * Refactor vjp and jvp methods in primitives.cpp * Add overloaded operators for logical AND and OR * format --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -610,6 +610,28 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
expected = np.logical_not(a)
|
||||
self.assertTrue(np.array_equal(result, expected))
|
||||
|
||||
def test_logical_and(self):
|
||||
a = mx.array([True, False, True, False])
|
||||
b = mx.array([True, True, False, False])
|
||||
result = mx.logical_and(a, b)
|
||||
expected = np.logical_and(a, b)
|
||||
self.assertTrue(np.array_equal(result, expected))
|
||||
|
||||
# test overloaded operator
|
||||
result = a & b
|
||||
self.assertTrue(np.array_equal(result, expected))
|
||||
|
||||
def test_logical_or(self):
|
||||
a = mx.array([True, False, True, False])
|
||||
b = mx.array([True, True, False, False])
|
||||
result = mx.logical_or(a, b)
|
||||
expected = np.logical_or(a, b)
|
||||
self.assertTrue(np.array_equal(result, expected))
|
||||
|
||||
# test overloaded operator
|
||||
result = a | b
|
||||
self.assertTrue(np.array_equal(result, expected))
|
||||
|
||||
def test_square(self):
|
||||
a = mx.array([0.1, 0.5, 1.0, 10.0])
|
||||
result = mx.square(a)
|
||||
|
@@ -80,6 +80,8 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
"multiply",
|
||||
"power",
|
||||
"subtract",
|
||||
"logical_or",
|
||||
"logical_and",
|
||||
]
|
||||
for opname in ops:
|
||||
with self.subTest(op=opname):
|
||||
|
Reference in New Issue
Block a user