feat: add logicalAnd and logicalOR (#386)

* feat: add logicalAnd and logicalOR

* run pre-commit

* Refactor logical_and and logical_or functions

* Add acknowledgement

* Add logical AND and logical OR operators

* Refactor logical_and and logical_or functions

* Add support for logical operators on bool arrays

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Add logical AND and OR operators for arrays and scalars

* Refactor vjp and jvp methods in primitives.cpp

* Add overloaded operators for logical AND and OR

* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
Nripesh Niketan 2024-01-08 19:00:05 +04:00 committed by GitHub
parent 022a944367
commit 73321b8097
17 changed files with 311 additions and 1 deletion
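
For context (not part of the diff): a minimal Python sketch of the API this commit adds, based on the bindings and tests below.

    import mlx.core as mx

    a = mx.array([True, False, True, False])
    b = mx.array([True, True, False, False])

    # element-wise logical ops added in this commit
    mx.logical_and(a, b)  # array([True, False, False, False], dtype=bool)
    mx.logical_or(a, b)   # array([True, True, True, False], dtype=bool)

    # the overloaded operators route through __and__ / __or__
    a & b  # same as mx.logical_and(a, b)
    a | b  # same as mx.logical_or(a, b)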


@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer` and safetensor support


@@ -60,6 +60,8 @@ Operations
log1p
logaddexp
logical_not
logical_and
logical_or
logsumexp
matmul
max


@@ -41,6 +41,8 @@ DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(NotEqual)
DEFAULT(Pad)


@@ -57,6 +57,8 @@ DEFAULT(Load)
DEFAULT(Log)
DEFAULT(Log1p)
DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)


@@ -8,6 +8,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/arange.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/erf.h"
#include "mlx/backend/common/threefry.h"
@@ -364,6 +365,20 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
  unary(in, out, [](auto x) { return !x; });
}

void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2); // LogicalAnd requires two input arrays
  auto& in1 = inputs[0];
  auto& in2 = inputs[1];
  binary(in1, in2, out, [](auto x, auto y) { return x && y; });
}

void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2); // LogicalOr requires two input arrays
  auto& in1 = inputs[0];
  auto& in2 = inputs[1];
  binary(in1, in2, out, [](auto x, auto y) { return x || y; });
}

void Negative::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];


@@ -131,6 +131,16 @@ struct Subtract {
  template <typename T> T operator()(T x, T y) { return x - y; }
};

struct LogicalAnd {
  template <typename T>
  T operator()(T x, T y) { return x && y; };
};

struct LogicalOr {
  template <typename T>
  T operator()(T x, T y) { return x || y; };
};

template <typename T, typename U, typename Op>
[[kernel]] void binary_op_s2s(
    device const T* a,
@@ -377,3 +387,6 @@ instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)


@@ -439,6 +439,20 @@ void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "lnot");
}

void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(
      inputs,
      out,
      "land"); // "land" is the kernel identifier for logical AND
}

void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(
      inputs,
      out,
      "lor"); // "lor" is the kernel identifier for logical OR
}

void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "lae");
}


@@ -48,6 +48,8 @@ NO_GPU(Load)
NO_GPU(Log)
NO_GPU(Log1p)
NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU(Matmul)
NO_GPU(Maximum)


@@ -1608,6 +1608,43 @@ array logical_not(const array& a, StreamOrDevice s /* = {} */) {
      {astype(a, bool_, s)});
}

array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  // Broadcast arrays to a common shape
  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
  return array(
      inputs[0].shape(),
      bool_,
      std::make_unique<LogicalAnd>(to_stream(s)),
      inputs);
}

array operator&&(const array& a, const array& b) {
  // check if a and b are bool arrays
  if (a.dtype() != bool_ || b.dtype() != bool_) {
    throw std::invalid_argument("[operator&&] only supported for bool arrays.");
  }
  return logical_and(a, b);
}

array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  // Broadcast arrays to a common shape
  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
  return array(
      inputs[0].shape(),
      bool_,
      std::make_unique<LogicalOr>(to_stream(s)),
      inputs);
}

array operator||(const array& a, const array& b) {
  // check if a and b are bool arrays
  if (a.dtype() != bool_ || b.dtype() != bool_) {
    throw std::invalid_argument(
        "[operator||] is only supported for bool arrays.");
  }
  return logical_or(a, b);
}

array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
  auto dtype = at_least_float(a.dtype());
  return divide(array(1.0f, dtype), a, to_stream(s));
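
Note (not part of the diff): `logical_and` and `logical_or` cast both inputs to `bool_` before building the op, so non-boolean arrays are accepted; only the C++ `operator&&`/`operator||` overloads insist on bool inputs. A small Python sketch of that casting behavior, consistent with the C++ tests further down:

    import mlx.core as mx

    x = mx.array([0.0, 1.0, 2.0])  # non-bool inputs are cast to bool internally
    y = mx.array([1, 1, 0])
    mx.logical_and(x, y)  # array([False, True, False], dtype=bool)
    mx.logical_or(x, y)   # array([True, True, True], dtype=bool)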


@@ -668,6 +668,14 @@ array sign(const array& a, StreamOrDevice s = {});
/** Logical not of an array */
array logical_not(const array& a, StreamOrDevice s = {});

/** Logical and of two arrays */
array logical_and(const array& a, const array& b, StreamOrDevice s = {});
array operator&&(const array& a, const array& b);

/** Logical or of two arrays */
array logical_or(const array& a, const array& b, StreamOrDevice s = {});
array operator||(const array& a, const array& b);

/** The reciprocal (1/x) of the elements in an array. */
array reciprocal(const array& a, StreamOrDevice s = {});


@@ -1338,6 +1338,64 @@ std::pair<array, int> LogicalNot::vmap(
  return {logical_not(inputs[0], stream()), axes[0]};
}

std::vector<array> LogicalAnd::vjp(
    const std::vector<array>& primals,
    const array& cotan,
    const std::vector<int>& argnums) {
  assert(primals.size() == 2);
  return {zeros_like(cotan, stream()), zeros_like(cotan, stream())};
}

array LogicalAnd::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  assert(primals.size() == 2);
  assert(argnums.size() <= 2);
  return zeros_like(primals[0], stream());
}

std::pair<array, int> LogicalAnd::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  assert(inputs.size() == 2);
  assert(axes.size() == 2);
  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
  return {logical_and(a, b, stream()), to_ax};
}

std::vector<array> LogicalOr::vjp(
    const std::vector<array>& primals,
    const array& cotan,
    const std::vector<int>& argnums) {
  assert(primals.size() == 2);
  return {zeros_like(cotan, stream()), zeros_like(cotan, stream())};
}

array LogicalOr::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  assert(primals.size() == 2);
  assert(argnums.size() <= 2);
  return zeros_like(primals[0], stream());
}

std::pair<array, int> LogicalOr::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  assert(inputs.size() == 2);
  assert(axes.size() == 2);
  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
  return {logical_or(a, b, stream()), to_ax};
}

std::vector<array> LogAddExp::vjp(
    const std::vector<array>& primals,
    const array& cotan,
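
Note (not part of the diff): both primitives define zero vjp and jvp, so no gradient flows through them. A hypothetical Python sketch of that behavior, assuming `mx.vjp` accepts boolean primals and cotangents:

    import mlx.core as mx

    def f(a, b):
        return mx.logical_and(a, b)

    primals = [mx.array([True, False]), mx.array([True, True])]
    cotangents = [mx.array([True, True])]  # same shape/dtype as the output
    outputs, vjps = mx.vjp(f, primals, cotangents)
    # vjps is expected to be all zeros (False) for both inputs,
    # since the op is treated as non-differentiable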


@@ -903,6 +903,44 @@ class LogicalNot : public Primitive {
  void eval(const std::vector<array>& inputs, array& out);
};

class LogicalAnd : public Primitive {
 public:
  explicit LogicalAnd(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(LogicalAnd)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

class LogicalOr : public Primitive {
 public:
  explicit LogicalOr(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(LogicalOr)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

class LogAddExp : public Primitive {
 public:
  explicit LogAddExp(Stream stream) : Primitive(stream){};


@@ -764,6 +764,18 @@ void init_array(py::module_& m) {
            return power(a, to_array(v, a.dtype()));
          },
          "other"_a)
      .def(
          "__and__",
          [](const array& a, const ScalarOrArray v) {
            return logical_and(a, to_array(v, a.dtype()));
          },
          "other"_a)
      .def(
          "__or__",
          [](const array& a, const ScalarOrArray v) {
            return logical_or(a, to_array(v, a.dtype()));
          },
          "other"_a)
      .def(
          "flatten",
          [](const array& a,
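
Since `__and__` and `__or__` take a `ScalarOrArray`, the overloaded operators also accept a Python scalar on the right-hand side. A small sketch (not part of the diff), based on the binding above:

    import mlx.core as mx

    a = mx.array([True, False, True])
    a & True   # array([True, False, True], dtype=bool); scalar promoted via to_array
    a | False  # array([True, False, True], dtype=bool)
    a & False  # array([False, False, False], dtype=bool)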


@@ -670,6 +670,51 @@ void init_ops(py::module_& m) {
        Returns:
            array: The boolean array containing the logical not of ``a``.
      )pbdoc");
  m.def(
      "logical_and",
      [](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) {
        return logical_and(to_array(a), to_array(b), s);
      },
      "a"_a,
      "b"_a,
      py::pos_only(),
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
        logical_and(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array

        Element-wise logical and.

        Args:
            a (array): First input array or scalar.
            b (array): Second input array or scalar.

        Returns:
            array: The boolean array containing the logical and of ``a`` and ``b``.
      )pbdoc");
  m.def(
      "logical_or",
      [](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) {
        return logical_or(to_array(a), to_array(b), s);
      },
      "a"_a,
      "b"_a,
      py::pos_only(),
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
        logical_or(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array

        Element-wise logical or.

        Args:
            a (array): First input array or scalar.
            b (array): Second input array or scalar.

        Returns:
            array: The boolean array containing the logical or of ``a`` and ``b``.
      )pbdoc");
  m.def(
      "logaddexp",
      [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {


@@ -610,6 +610,28 @@ class TestOps(mlx_tests.MLXTestCase):
        expected = np.logical_not(a)
        self.assertTrue(np.array_equal(result, expected))

    def test_logical_and(self):
        a = mx.array([True, False, True, False])
        b = mx.array([True, True, False, False])
        result = mx.logical_and(a, b)
        expected = np.logical_and(a, b)
        self.assertTrue(np.array_equal(result, expected))

        # test overloaded operator
        result = a & b
        self.assertTrue(np.array_equal(result, expected))

    def test_logical_or(self):
        a = mx.array([True, False, True, False])
        b = mx.array([True, True, False, False])
        result = mx.logical_or(a, b)
        expected = np.logical_or(a, b)
        self.assertTrue(np.array_equal(result, expected))

        # test overloaded operator
        result = a | b
        self.assertTrue(np.array_equal(result, expected))

    def test_square(self):
        a = mx.array([0.1, 0.5, 1.0, 10.0])
        result = mx.square(a)


@@ -80,6 +80,8 @@ class TestVmap(mlx_tests.MLXTestCase):
            "multiply",
            "power",
            "subtract",
            "logical_or",
            "logical_and",
        ]
        for opname in ops:
            with self.subTest(op=opname):


@@ -795,6 +795,44 @@ TEST_CASE("test arithmetic unary ops") {
    CHECK_EQ(y.item<bool>(), true);
  }

  // Test logical and
  {
    array x(true);
    array y(true);
    CHECK_EQ(logical_and(x, y).item<bool>(), true);

    x = array(1.0f);
    y = array(1.0f);
    auto z = logical_and(x, y);
    CHECK_EQ(z.dtype(), bool_);
    CHECK_EQ(z.item<bool>(), true);

    x = array(0);
    y = array(1.0f);
    z = logical_and(x, y);
    CHECK_EQ(z.dtype(), bool_);
    CHECK_EQ(z.item<bool>(), false);
  }

  // Test logical or
  {
    array x(false);
    array y(false);
    CHECK_EQ(logical_or(x, y).item<bool>(), false);

    x = array(1.0f);
    y = array(1.0f);
    auto z = logical_or(x, y);
    CHECK_EQ(z.dtype(), bool_);
    CHECK_EQ(z.item<bool>(), true);

    x = array(0);
    y = array(1.0f);
    z = logical_or(x, y);
    CHECK_EQ(z.dtype(), bool_);
    CHECK_EQ(z.item<bool>(), true);
  }

  // Test abs
  {
    array x({-1.0f, 0.0f, 1.0f});