From 73321b8097449f4739d8b3a6b365e64c15610e44 Mon Sep 17 00:00:00 2001
From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com>
Date: Mon, 8 Jan 2024 19:00:05 +0400
Subject: [PATCH] feat: add logicalAnd and logicalOR (#386)

* feat: add logicalAnd and logicalOR

* run pre-commit

* Refactor logical_and and logical_or functions

* Add acknowledgement

* Add logical AND and logical OR operators

* Refactor logical_and and logical_or functions

* Add support for logical operators on bool arrays

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun

* Add logical AND and OR operators for arrays and scalars

* Refactor vjp and jvp methods in primitives.cpp

* Add overloaded operators for logical AND and OR

* format

---------

Co-authored-by: Awni Hannun
Co-authored-by: Awni Hannun
---
 ACKNOWLEDGMENTS.md                        |  2 +-
 docs/src/python/ops.rst                   |  2 +
 mlx/backend/accelerate/primitives.cpp     |  2 +
 mlx/backend/common/default_primitives.cpp |  2 +
 mlx/backend/common/primitives.cpp         | 15 ++++++
 mlx/backend/metal/kernels/binary.metal    | 13 +++++
 mlx/backend/metal/primitives.cpp          | 14 ++++++
 mlx/backend/no_metal/primitives.cpp       |  2 +
 mlx/ops.cpp                               | 37 +++++++++++++++
 mlx/ops.h                                 |  8 ++++
 mlx/primitives.cpp                        | 58 +++++++++++++++++++++++
 mlx/primitives.h                          | 38 +++++++++++++++
 python/src/array.cpp                      | 12 +++++
 python/src/ops.cpp                        | 45 ++++++++++++++++++
 python/tests/test_ops.py                  | 22 +++++++++
 python/tests/test_vmap.py                 |  2 +
 tests/ops_tests.cpp                       | 38 +++++++++++++++
 17 files changed, 311 insertions(+), 1 deletion(-)

diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index dce0ddc81..bb16d4b1f 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
 
 MLX was developed with contributions from the following individuals:
 
-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOr` ops.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer` and safetensor support
diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index ffeee71da..0eff8b7c6 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -60,6 +60,8 @@ Operations
    log1p
    logaddexp
    logical_not
+   logical_and
+   logical_or
    logsumexp
    matmul
    max
diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index 912c460e6..335f26990 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -41,6 +41,8 @@ DEFAULT(Less)
 DEFAULT(LessEqual)
 DEFAULT(Load)
 DEFAULT(LogicalNot)
+DEFAULT(LogicalAnd)
+DEFAULT(LogicalOr)
 DEFAULT(LogAddExp)
 DEFAULT(NotEqual)
 DEFAULT(Pad)
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index 4c499aa90..917286ea7 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -57,6 +57,8 @@ DEFAULT(Load)
 DEFAULT(Log)
 DEFAULT(Log1p)
 DEFAULT(LogicalNot)
+DEFAULT(LogicalAnd)
+DEFAULT(LogicalOr)
 DEFAULT(LogAddExp)
 DEFAULT(Maximum)
 DEFAULT(Minimum)
diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp
index 2218c708c..889a9841b 100644
--- a/mlx/backend/common/primitives.cpp
+++ b/mlx/backend/common/primitives.cpp
@@ -8,6 +8,7 @@
 #include "mlx/allocator.h"
 #include "mlx/backend/common/arange.h"
+#include "mlx/backend/common/binary.h"
 #include "mlx/backend/common/copy.h"
 #include "mlx/backend/common/erf.h"
 #include "mlx/backend/common/threefry.h"
@@ -364,6 +365,20 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
   unary(in, out, [](auto x) { return !x; });
 }
 
+void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2); // LogicalAnd requires two input arrays
+  auto& in1 = inputs[0];
+  auto& in2 = inputs[1];
+  binary(in1, in2, out, [](auto x, auto y) { return x && y; });
+}
+
+void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2); // LogicalOr requires two input arrays
+  auto& in1 = inputs[0];
+  auto& in2 = inputs[1];
+  binary(in1, in2, out, [](auto x, auto y) { return x || y; });
+}
+
 void Negative::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal
index a4f599064..12b6d1b1b 100644
--- a/mlx/backend/metal/kernels/binary.metal
+++ b/mlx/backend/metal/kernels/binary.metal
@@ -131,6 +131,16 @@ struct Subtract {
   template <typename T>
   T operator()(T x, T y) { return x - y; }
 };
 
+struct LogicalAnd {
+  template <typename T>
+  T operator()(T x, T y) { return x && y; };
+};
+
+struct LogicalOr {
+  template <typename T>
+  T operator()(T x, T y) { return x || y; };
+};
+
 template <typename T, typename U, typename Op>
 [[kernel]] void binary_op_s2s(
     device const T* a,
@@ -377,3 +387,6 @@ instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
 instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
 instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
 instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
+
+instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
+instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
\ No newline at end of file
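For reference, here is a minimal sketch of what these element-wise kernels compute, written against the Python bindings added later in this patch (the outputs in comments are illustrative, not captured from a run):

    import mlx.core as mx

    a = mx.array([True, True, False, False])
    b = mx.array([True, False, True, False])
    mx.logical_and(a, b)  # array([True, False, False, False], dtype=bool)
    mx.logical_or(a, b)   # array([True, True, True, False], dtype=bool)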
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 94f1217a1..35f91ffcb 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -439,6 +439,20 @@ void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
   unary_op(inputs, out, "lnot");
 }
 
+void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
+  binary_op(
+      inputs,
+      out,
+      "land"); // Assume "land" is the operation identifier for logical AND
+}
+
+void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
+  binary_op(
+      inputs,
+      out,
+      "lor"); // Assume "lor" is the operation identifier for logical OR
+}
+
 void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
   binary_op(inputs, out, "lae");
 }
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index e2f92b093..4c4f1a76e 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -48,6 +48,8 @@ NO_GPU(Load)
 NO_GPU(Log)
 NO_GPU(Log1p)
 NO_GPU(LogicalNot)
+NO_GPU(LogicalAnd)
+NO_GPU(LogicalOr)
 NO_GPU(LogAddExp)
 NO_GPU(Matmul)
 NO_GPU(Maximum)
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index aff0d46ed..4a2740d1f 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -1608,6 +1608,43 @@ array logical_not(const array& a, StreamOrDevice s /* = {} */) {
       {astype(a, bool_, s)});
 }
 
+array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  // Broadcast arrays to a common shape
+  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
+
+  return array(
+      inputs[0].shape(),
+      bool_,
+      std::make_unique<LogicalAnd>(to_stream(s)),
+      inputs);
+}
+array operator&&(const array& a, const array& b) {
+  // check if a and b are bool arrays
+  if (a.dtype() != bool_ || b.dtype() != bool_) {
+    throw std::invalid_argument("[operator&&] only supported for bool arrays.");
+  }
+  return logical_and(a, b);
+}
+
+array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  // Broadcast arrays to a common shape
+  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
+
+  return array(
+      inputs[0].shape(),
+      bool_,
+      std::make_unique<LogicalOr>(to_stream(s)),
+      inputs);
+}
+array operator||(const array& a, const array& b) {
+  // check if a and b are bool arrays
+  if (a.dtype() != bool_ || b.dtype() != bool_) {
+    throw std::invalid_argument(
+        "[operator||] is only supported for bool arrays.");
+  }
+  return logical_or(a, b);
+}
+
 array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   return divide(array(1.0f, dtype), a, to_stream(s));
diff --git a/mlx/ops.h b/mlx/ops.h
index ae7776223..2b8061f64 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -668,6 +668,14 @@ array sign(const array& a, StreamOrDevice s = {});
 /** Logical not of an array */
 array logical_not(const array& a, StreamOrDevice s = {});
 
+/** Logical and of two arrays */
+array logical_and(const array& a, const array& b, StreamOrDevice s = {});
+array operator&&(const array& a, const array& b);
+
+/** Logical or of two arrays */
+array logical_or(const array& a, const array& b, StreamOrDevice s = {});
+array operator||(const array& a, const array& b);
+
 /** The reciprocal (1/x) of the elements in an array. */
 array reciprocal(const array& a, StreamOrDevice s = {});
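Because `logical_and` and `logical_or` above first apply `astype(..., bool_, s)` and then `broadcast_arrays`, non-bool and scalar inputs are accepted by the ops themselves, while the `operator&&`/`operator||` overloads deliberately reject non-bool arrays. A hedged sketch of the implied behavior (outputs illustrative):

    import mlx.core as mx

    # Non-bool inputs are cast to bool first, so any nonzero value acts as True.
    mx.logical_and(mx.array([0, 1, 2]), mx.array([1, 1, 0]))
    # -> array([False, True, False], dtype=bool)

    # Inputs are broadcast to a common shape, so scalars work too.
    mx.logical_or(mx.array([True, False]), False)
    # -> array([True, False], dtype=bool)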
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index ffcaefb44..d2b3c23e2 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -1338,6 +1338,64 @@ std::pair<array, int> LogicalNot::vmap(
   return {logical_not(inputs[0], stream()), axes[0]};
 }
 
+std::vector<array> LogicalAnd::vjp(
+    const std::vector<array>& primals,
+    const array& cotan,
+    const std::vector<int>& argnums) {
+  assert(primals.size() == 2);
+
+  return {zeros_like(cotan, stream()), zeros_like(cotan, stream())};
+}
+
+array LogicalAnd::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  assert(primals.size() == 2);
+  assert(argnums.size() <= 2);
+
+  return zeros_like(primals[0], stream());
+}
+
+std::pair<array, int> LogicalAnd::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 2);
+  assert(axes.size() == 2);
+
+  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
+  return {logical_and(a, b, stream()), to_ax};
+}
+
+std::vector<array> LogicalOr::vjp(
+    const std::vector<array>& primals,
+    const array& cotan,
+    const std::vector<int>& argnums) {
+  assert(primals.size() == 2);
+
+  return {zeros_like(cotan, stream()), zeros_like(cotan, stream())};
+}
+
+array LogicalOr::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  assert(primals.size() == 2);
+  assert(argnums.size() <= 2);
+
+  return zeros_like(primals[0], stream());
+}
+
+std::pair<array, int> LogicalOr::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 2);
+  assert(axes.size() == 2);
+
+  auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
+  return {logical_or(a, b, stream()), to_ax};
+}
+
 std::vector<array> LogAddExp::vjp(
     const std::vector<array>& primals,
     const array& cotan,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index e87e934ab..c69b830a2 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -903,6 +903,44 @@ class LogicalNot : public Primitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class LogicalAnd : public Primitive {
+ public:
+  explicit LogicalAnd(Stream stream) : Primitive(stream){};
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  std::pair<array, int> vmap(
+      const std::vector<array>& inputs,
+      const std::vector<int>& axes) override;
+
+  DEFINE_GRADS()
+  DEFINE_PRINT(LogicalAnd)
+  DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+  void eval(const std::vector<array>& inputs, array& out);
+};
+
+class LogicalOr : public Primitive {
+ public:
+  explicit LogicalOr(Stream stream) : Primitive(stream){};
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  std::pair<array, int> vmap(
+      const std::vector<array>& inputs,
+      const std::vector<int>& axes) override;
+
+  DEFINE_GRADS()
+  DEFINE_PRINT(LogicalOr)
+  DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+  void eval(const std::vector<array>& inputs, array& out);
+};
+
 class LogAddExp : public Primitive {
  public:
   explicit LogAddExp(Stream stream) : Primitive(stream){};
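Both `vjp` and `jvp` above return zero arrays: logical ops are piecewise constant, so no gradient flows through them. A minimal sketch of what this means for callers, assuming the `mx.vjp` transform with list-valued primals and cotangents:

    import mlx.core as mx

    primals = [mx.array([True, False]), mx.array([True, True])]
    cotans = [mx.array([True, True])]
    outs, grads = mx.vjp(lambda a, b: mx.logical_and(a, b), primals, cotans)
    # grads[0] and grads[1] are all-False arrays, matching the
    # zeros_like(cotan, stream()) returned by LogicalAnd::vjp above.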
diff --git a/python/src/array.cpp b/python/src/array.cpp
index dddbcce29..61b8a53f2 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -764,6 +764,18 @@ void init_array(py::module_& m) {
             return power(a, to_array(v, a.dtype()));
           },
           "other"_a)
+      .def(
+          "__and__",
+          [](const array& a, const ScalarOrArray v) {
+            return logical_and(a, to_array(v, a.dtype()));
+          },
+          "other"_a)
+      .def(
+          "__or__",
+          [](const array& a, const ScalarOrArray v) {
+            return logical_or(a, to_array(v, a.dtype()));
+          },
+          "other"_a)
       .def(
           "flatten",
           [](const array& a,
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 50bee45a6..1b865fbd1 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -670,6 +670,51 @@ void init_ops(py::module_& m) {
         Returns:
             array: The boolean array containing the logical not of ``a``.
       )pbdoc");
+  m.def(
+      "logical_and",
+      [](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) {
+        return logical_and(to_array(a), to_array(b), s);
+      },
+      "a"_a,
+      "b"_a,
+      py::pos_only(),
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        logical_and(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
+
+        Element-wise logical and.
+
+        Args:
+            a (array): First input array or scalar.
+            b (array): Second input array or scalar.
+
+        Returns:
+            array: The boolean array containing the logical and of ``a`` and ``b``.
+      )pbdoc");
+
+  m.def(
+      "logical_or",
+      [](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) {
+        return logical_or(to_array(a), to_array(b), s);
+      },
+      "a"_a,
+      "b"_a,
+      py::pos_only(),
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        logical_or(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
+
+        Element-wise logical or.
+
+        Args:
+            a (array): First input array or scalar.
+            b (array): Second input array or scalar.
+
+        Returns:
+            array: The boolean array containing the logical or of ``a`` and ``b``.
+      )pbdoc");
   m.def(
       "logaddexp",
       [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 393c8dc9e..13f814fc1 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -610,6 +610,28 @@ class TestOps(mlx_tests.MLXTestCase):
         expected = np.logical_not(a)
         self.assertTrue(np.array_equal(result, expected))
 
+    def test_logical_and(self):
+        a = mx.array([True, False, True, False])
+        b = mx.array([True, True, False, False])
+        result = mx.logical_and(a, b)
+        expected = np.logical_and(a, b)
+        self.assertTrue(np.array_equal(result, expected))
+
+        # test overloaded operator
+        result = a & b
+        self.assertTrue(np.array_equal(result, expected))
+
+    def test_logical_or(self):
+        a = mx.array([True, False, True, False])
+        b = mx.array([True, True, False, False])
+        result = mx.logical_or(a, b)
+        expected = np.logical_or(a, b)
+        self.assertTrue(np.array_equal(result, expected))
+
+        # test overloaded operator
+        result = a | b
+        self.assertTrue(np.array_equal(result, expected))
+
     def test_square(self):
         a = mx.array([0.1, 0.5, 1.0, 10.0])
         result = mx.square(a)
diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py
index 63f616aa3..695ddd65f 100644
--- a/python/tests/test_vmap.py
+++ b/python/tests/test_vmap.py
@@ -80,6 +80,8 @@ class TestVmap(mlx_tests.MLXTestCase):
             "multiply",
             "power",
             "subtract",
+            "logical_or",
+            "logical_and",
         ]
         for opname in ops:
             with self.subTest(op=opname):
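With the `__and__`/`__or__` bindings above, the Python `&` and `|` operators dispatch to the new ops, and the `test_vmap.py` additions indicate the primitives also compose with `mx.vmap`. A short usage sketch (shapes and outputs illustrative):

    import mlx.core as mx

    a = mx.array([True, False, True, False])
    b = mx.array([True, True, False, False])
    a & b  # same as mx.logical_and(a, b)
    a | b  # same as mx.logical_or(a, b)

    # Batched application over the leading axis:
    mx.vmap(mx.logical_and)(a.reshape(2, 2), b.reshape(2, 2))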
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index d34497d29..d1cd9552b 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -795,6 +795,44 @@ TEST_CASE("test arithmetic unary ops") {
     CHECK_EQ(y.item<bool>(), true);
   }
 
+  // Test logical and
+  {
+    array x(true);
+    array y(true);
+    CHECK_EQ(logical_and(x, y).item<bool>(), true);
+
+    x = array(1.0f);
+    y = array(1.0f);
+    auto z = logical_and(x, y);
+    CHECK_EQ(z.dtype(), bool_);
+    CHECK_EQ(z.item<bool>(), true);
+
+    x = array(0);
+    y = array(1.0f);
+    z = logical_and(x, y);
+    CHECK_EQ(z.dtype(), bool_);
+    CHECK_EQ(z.item<bool>(), false);
+  }
+
+  // Test logical or
+  {
+    array x(false);
+    array y(false);
+    CHECK_EQ(logical_or(x, y).item<bool>(), false);
+
+    x = array(1.0f);
+    y = array(1.0f);
+    auto z = logical_or(x, y);
+    CHECK_EQ(z.dtype(), bool_);
+    CHECK_EQ(z.item<bool>(), true);
+
+    x = array(0);
+    y = array(1.0f);
+    z = logical_or(x, y);
+    CHECK_EQ(z.dtype(), bool_);
+    CHECK_EQ(z.item<bool>(), true);
+  }
+
   // Test abs
   {
     array x({-1.0f, 0.0f, 1.0f});