feat: add logicalAnd and logicalOR (#386)

* feat: add logicalAnd and logicalOR

* run pre-commit

* Refactor logical_and and logical_or functions

* Add acknowledgement

* Add logical AND and logical OR operators

* Refactor logical_and and logical_or functions

* Add support for logical operators on bool arrays

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Add logical AND and OR operators for arrays and scalars

* Refactor vjp and jvp methods in primitives.cpp

* Add overloaded operators for logical AND and OR

* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan
2024-01-08 19:00:05 +04:00
committed by GitHub
parent 022a944367
commit 73321b8097
17 changed files with 311 additions and 1 deletions

View File

@@ -41,6 +41,8 @@ DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(NotEqual)
DEFAULT(Pad)

View File

@@ -57,6 +57,8 @@ DEFAULT(Load)
DEFAULT(Log)
DEFAULT(Log1p)
DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)

View File

@@ -8,6 +8,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/arange.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/erf.h"
#include "mlx/backend/common/threefry.h"
@@ -364,6 +365,20 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
unary(in, out, [](auto x) { return !x; });
}
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, [](auto x, auto y) { return x && y; });
}
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, [](auto x, auto y) { return x || y; });
}
void Negative::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];

View File

@@ -131,6 +131,16 @@ struct Subtract {
template <typename T> T operator()(T x, T y) { return x - y; }
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) { return x && y; };
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) { return x || y; };
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_s2s(
device const T* a,
@@ -377,3 +387,6 @@ instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)

View File

@@ -439,6 +439,20 @@ void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "lnot");
}
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(
inputs,
out,
"land"); // Assume "land" is the operation identifier for logical AND
}
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(
inputs,
out,
"lor"); // Assume "lor" is the operation identifier for logical OR
}
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "lae");
}

View File

@@ -48,6 +48,8 @@ NO_GPU(Load)
NO_GPU(Log)
NO_GPU(Log1p)
NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU(Matmul)
NO_GPU(Maximum)