mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	feat: add logicalAnd and logicalOR (#386)
* feat: add logicalAnd and logicalOR * run pre-commit * Refactor logical_and and logical_or functions * Add acknowledgement * Add logical AND and logical OR operators * Refactor logical_and and logical_or functions * Add support for logical operators on bool arrays * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Add logical AND and OR operators for arrays and scalars * Refactor vjp and jvp methods in primitives.cpp * Add overloaded operators for logical AND and OR * format --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -764,6 +764,18 @@ void init_array(py::module_& m) { | ||||
|             return power(a, to_array(v, a.dtype())); | ||||
|           }, | ||||
|           "other"_a) | ||||
|       .def( | ||||
|           "__and__", | ||||
|           [](const array& a, const ScalarOrArray v) { | ||||
|             return logical_and(a, to_array(v, a.dtype())); | ||||
|           }, | ||||
|           "other"_a) | ||||
|       .def( | ||||
|           "__or__", | ||||
|           [](const array& a, const ScalarOrArray v) { | ||||
|             return logical_or(a, to_array(v, a.dtype())); | ||||
|           }, | ||||
|           "other"_a) | ||||
|       .def( | ||||
|           "flatten", | ||||
|           [](const array& a, | ||||
|   | ||||
| @@ -670,6 +670,51 @@ void init_ops(py::module_& m) { | ||||
|         Returns: | ||||
|             array: The boolean array containing the logical not of ``a``. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "logical_and", | ||||
|       [](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) { | ||||
|         return logical_and(to_array(a), to_array(b), s); | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "b"_a, | ||||
|       py::pos_only(), | ||||
|       py::kw_only(), | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         logical_and(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array | ||||
|  | ||||
|         Element-wise logical and. | ||||
|  | ||||
|         Args: | ||||
|             a (array): First input array or scalar. | ||||
|             b (array): Second input array or scalar. | ||||
|  | ||||
|         Returns: | ||||
|             array: The boolean array containing the logical and of ``a`` and ``b``. | ||||
|     )pbdoc"); | ||||
|  | ||||
|   m.def( | ||||
|       "logical_or", | ||||
|       [](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) { | ||||
|         return logical_or(to_array(a), to_array(b), s); | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "b"_a, | ||||
|       py::pos_only(), | ||||
|       py::kw_only(), | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         logical_or(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array | ||||
|  | ||||
|         Element-wise logical or. | ||||
|  | ||||
|         Args: | ||||
|             a (array): First input array or scalar. | ||||
|             b (array): Second input array or scalar. | ||||
|  | ||||
|         Returns: | ||||
|             array: The boolean array containing the logical or of ``a`` and ``b``. | ||||
|     )pbdoc"); | ||||
|   m.def( | ||||
|       "logaddexp", | ||||
|       [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { | ||||
|   | ||||
| @@ -610,6 +610,28 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         expected = np.logical_not(a) | ||||
|         self.assertTrue(np.array_equal(result, expected)) | ||||
|  | ||||
|     def test_logical_and(self): | ||||
|         a = mx.array([True, False, True, False]) | ||||
|         b = mx.array([True, True, False, False]) | ||||
|         result = mx.logical_and(a, b) | ||||
|         expected = np.logical_and(a, b) | ||||
|         self.assertTrue(np.array_equal(result, expected)) | ||||
|  | ||||
|         # test overloaded operator | ||||
|         result = a & b | ||||
|         self.assertTrue(np.array_equal(result, expected)) | ||||
|  | ||||
|     def test_logical_or(self): | ||||
|         a = mx.array([True, False, True, False]) | ||||
|         b = mx.array([True, True, False, False]) | ||||
|         result = mx.logical_or(a, b) | ||||
|         expected = np.logical_or(a, b) | ||||
|         self.assertTrue(np.array_equal(result, expected)) | ||||
|  | ||||
|         # test overloaded operator | ||||
|         result = a | b | ||||
|         self.assertTrue(np.array_equal(result, expected)) | ||||
|  | ||||
|     def test_square(self): | ||||
|         a = mx.array([0.1, 0.5, 1.0, 10.0]) | ||||
|         result = mx.square(a) | ||||
|   | ||||
| @@ -80,6 +80,8 @@ class TestVmap(mlx_tests.MLXTestCase): | ||||
|             "multiply", | ||||
|             "power", | ||||
|             "subtract", | ||||
|             "logical_or", | ||||
|             "logical_and", | ||||
|         ] | ||||
|         for opname in ops: | ||||
|             with self.subTest(op=opname): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Nripesh Niketan
					Nripesh Niketan