mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-25 12:48:14 +08:00 
			
		
		
		
	feat: add logicalAnd and logicalOR (#386)
* feat: add logicalAnd and logicalOR * run pre-commit * Refactor logical_and and logical_or functions * Add acknowledgement * Add logical AND and logical OR operators * Refactor logical_and and logical_or functions * Add support for logical operators on bool arrays * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Add logical AND and OR operators for arrays and scalars * Refactor vjp and jvp methods in primitives.cpp * Add overloaded operators for logical AND and OR * format --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -41,6 +41,8 @@ DEFAULT(Less) | ||||
| DEFAULT(LessEqual) | ||||
| DEFAULT(Load) | ||||
| DEFAULT(LogicalNot) | ||||
| DEFAULT(LogicalAnd) | ||||
| DEFAULT(LogicalOr) | ||||
| DEFAULT(LogAddExp) | ||||
| DEFAULT(NotEqual) | ||||
| DEFAULT(Pad) | ||||
|   | ||||
| @@ -57,6 +57,8 @@ DEFAULT(Load) | ||||
| DEFAULT(Log) | ||||
| DEFAULT(Log1p) | ||||
| DEFAULT(LogicalNot) | ||||
| DEFAULT(LogicalAnd) | ||||
| DEFAULT(LogicalOr) | ||||
| DEFAULT(LogAddExp) | ||||
| DEFAULT(Maximum) | ||||
| DEFAULT(Minimum) | ||||
|   | ||||
| @@ -8,6 +8,7 @@ | ||||
|  | ||||
| #include "mlx/allocator.h" | ||||
| #include "mlx/backend/common/arange.h" | ||||
| #include "mlx/backend/common/binary.h" | ||||
| #include "mlx/backend/common/copy.h" | ||||
| #include "mlx/backend/common/erf.h" | ||||
| #include "mlx/backend/common/threefry.h" | ||||
| @@ -364,6 +365,20 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) { | ||||
|   unary(in, out, [](auto x) { return !x; }); | ||||
| } | ||||
|  | ||||
| void LogicalAnd::eval(const std::vector<array>& inputs, array& out) { | ||||
|   assert(inputs.size() == 2); // LogicalAnd requires two input arrays | ||||
|   auto& in1 = inputs[0]; | ||||
|   auto& in2 = inputs[1]; | ||||
|   binary(in1, in2, out, [](auto x, auto y) { return x && y; }); | ||||
| } | ||||
|  | ||||
| void LogicalOr::eval(const std::vector<array>& inputs, array& out) { | ||||
|   assert(inputs.size() == 2); // LogicalOr requires two input arrays | ||||
|   auto& in1 = inputs[0]; | ||||
|   auto& in2 = inputs[1]; | ||||
|   binary(in1, in2, out, [](auto x, auto y) { return x || y; }); | ||||
| } | ||||
|  | ||||
| void Negative::eval(const std::vector<array>& inputs, array& out) { | ||||
|   assert(inputs.size() == 1); | ||||
|   auto& in = inputs[0]; | ||||
|   | ||||
| @@ -131,6 +131,16 @@ struct Subtract { | ||||
|   template <typename T> T operator()(T x, T y) { return x - y; } | ||||
| }; | ||||
|  | ||||
| struct LogicalAnd { | ||||
|     template <typename T> | ||||
|     T operator()(T x, T y) { return x && y; }; | ||||
| }; | ||||
|  | ||||
| struct LogicalOr { | ||||
|     template <typename T> | ||||
|     T operator()(T x, T y) { return x || y; }; | ||||
| }; | ||||
|  | ||||
| template <typename T, typename U, typename Op> | ||||
| [[kernel]] void binary_op_s2s( | ||||
|     device const T* a, | ||||
| @@ -377,3 +387,6 @@ instantiate_binary_all(naneq, float16, half, bool, NaNEqual) | ||||
| instantiate_binary_all(naneq, float32, float, bool, NaNEqual) | ||||
| instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual) | ||||
| instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual) | ||||
|  | ||||
| instantiate_binary_all(lor, bool_, bool, bool, LogicalOr) | ||||
| instantiate_binary_all(land, bool_, bool, bool, LogicalAnd) | ||||
| @@ -439,6 +439,20 @@ void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   unary_op(inputs, out, "lnot"); | ||||
| } | ||||
|  | ||||
| void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   binary_op( | ||||
|       inputs, | ||||
|       out, | ||||
|       "land"); // Assume "land" is the operation identifier for logical AND | ||||
| } | ||||
|  | ||||
| void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   binary_op( | ||||
|       inputs, | ||||
|       out, | ||||
|       "lor"); // Assume "lor" is the operation identifier for logical OR | ||||
| } | ||||
|  | ||||
| void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) { | ||||
|   binary_op(inputs, out, "lae"); | ||||
| } | ||||
|   | ||||
| @@ -48,6 +48,8 @@ NO_GPU(Load) | ||||
| NO_GPU(Log) | ||||
| NO_GPU(Log1p) | ||||
| NO_GPU(LogicalNot) | ||||
| NO_GPU(LogicalAnd) | ||||
| NO_GPU(LogicalOr) | ||||
| NO_GPU(LogAddExp) | ||||
| NO_GPU(Matmul) | ||||
| NO_GPU(Maximum) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Nripesh Niketan
					Nripesh Niketan