mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
feat: add logicalAnd and logicalOR (#386)
* feat: add logicalAnd and logicalOR * run pre-commit * Refactor logical_and and logical_or functions * Add acknowledgement * Add logical AND and logical OR operators * Refactor logical_and and logical_or functions * Add support for logical operators on bool arrays * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Add logical AND and OR operators for arrays and scalars * Refactor vjp and jvp methods in primitives.cpp * Add overloaded operators for logical AND and OR * format --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
022a944367
commit
73321b8097
@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
|
||||
|
||||
MLX was developed with contributions from the following individuals:
|
||||
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
|
||||
- Juarez Bochi: Fixed bug in cross attention.
|
||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer` and safetensor support
|
||||
|
@ -60,6 +60,8 @@ Operations
|
||||
log1p
|
||||
logaddexp
|
||||
logical_not
|
||||
logical_and
|
||||
logical_or
|
||||
logsumexp
|
||||
matmul
|
||||
max
|
||||
|
@ -41,6 +41,8 @@ DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
DEFAULT(LogicalNot)
|
||||
DEFAULT(LogicalAnd)
|
||||
DEFAULT(LogicalOr)
|
||||
DEFAULT(LogAddExp)
|
||||
DEFAULT(NotEqual)
|
||||
DEFAULT(Pad)
|
||||
|
@ -57,6 +57,8 @@ DEFAULT(Load)
|
||||
DEFAULT(Log)
|
||||
DEFAULT(Log1p)
|
||||
DEFAULT(LogicalNot)
|
||||
DEFAULT(LogicalAnd)
|
||||
DEFAULT(LogicalOr)
|
||||
DEFAULT(LogAddExp)
|
||||
DEFAULT(Maximum)
|
||||
DEFAULT(Minimum)
|
||||
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/arange.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/erf.h"
|
||||
#include "mlx/backend/common/threefry.h"
|
||||
@ -364,6 +365,20 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
|
||||
unary(in, out, [](auto x) { return !x; });
|
||||
}
|
||||
|
||||
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, [](auto x, auto y) { return x && y; });
|
||||
}
|
||||
|
||||
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, [](auto x, auto y) { return x || y; });
|
||||
}
|
||||
|
||||
void Negative::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
@ -131,6 +131,16 @@ struct Subtract {
|
||||
template <typename T> T operator()(T x, T y) { return x - y; }
|
||||
};
|
||||
|
||||
struct LogicalAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) { return x && y; };
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) { return x || y; };
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_s2s(
|
||||
device const T* a,
|
||||
@ -377,3 +387,6 @@ instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
|
||||
|
||||
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
|
||||
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
|
@ -439,6 +439,20 @@ void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "lnot");
|
||||
}
|
||||
|
||||
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(
|
||||
inputs,
|
||||
out,
|
||||
"land"); // Assume "land" is the operation identifier for logical AND
|
||||
}
|
||||
|
||||
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(
|
||||
inputs,
|
||||
out,
|
||||
"lor"); // Assume "lor" is the operation identifier for logical OR
|
||||
}
|
||||
|
||||
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "lae");
|
||||
}
|
||||
|
@ -48,6 +48,8 @@ NO_GPU(Load)
|
||||
NO_GPU(Log)
|
||||
NO_GPU(Log1p)
|
||||
NO_GPU(LogicalNot)
|
||||
NO_GPU(LogicalAnd)
|
||||
NO_GPU(LogicalOr)
|
||||
NO_GPU(LogAddExp)
|
||||
NO_GPU(Matmul)
|
||||
NO_GPU(Maximum)
|
||||
|
37
mlx/ops.cpp
37
mlx/ops.cpp
@ -1608,6 +1608,43 @@ array logical_not(const array& a, StreamOrDevice s /* = {} */) {
|
||||
{astype(a, bool_, s)});
|
||||
}
|
||||
|
||||
array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
// Broadcast arrays to a common shape
|
||||
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
|
||||
|
||||
return array(
|
||||
inputs[0].shape(),
|
||||
bool_,
|
||||
std::make_unique<LogicalAnd>(to_stream(s)),
|
||||
inputs);
|
||||
}
|
||||
array operator&&(const array& a, const array& b) {
|
||||
// check if a and b are bool arrays
|
||||
if (a.dtype() != bool_ || b.dtype() != bool_) {
|
||||
throw std::invalid_argument("[operator&&] only supported for bool arrays.");
|
||||
}
|
||||
return logical_and(a, b);
|
||||
}
|
||||
|
||||
array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
// Broadcast arrays to a common shape
|
||||
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
|
||||
|
||||
return array(
|
||||
inputs[0].shape(),
|
||||
bool_,
|
||||
std::make_unique<LogicalOr>(to_stream(s)),
|
||||
inputs);
|
||||
}
|
||||
array operator||(const array& a, const array& b) {
|
||||
// check if a and b are bool arrays
|
||||
if (a.dtype() != bool_ || b.dtype() != bool_) {
|
||||
throw std::invalid_argument(
|
||||
"[operator||] is only supported for bool arrays.");
|
||||
}
|
||||
return logical_or(a, b);
|
||||
}
|
||||
|
||||
array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
return divide(array(1.0f, dtype), a, to_stream(s));
|
||||
|
@ -668,6 +668,14 @@ array sign(const array& a, StreamOrDevice s = {});
|
||||
/** Logical not of an array */
|
||||
array logical_not(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Logical and of two arrays */
|
||||
array logical_and(const array& a, const array& b, StreamOrDevice s = {});
|
||||
array operator&&(const array& a, const array& b);
|
||||
|
||||
/** Logical or of two arrays */
|
||||
array logical_or(const array& a, const array& b, StreamOrDevice s = {});
|
||||
array operator||(const array& a, const array& b);
|
||||
|
||||
/** The reciprocal (1/x) of the elements in an array. */
|
||||
array reciprocal(const array& a, StreamOrDevice s = {});
|
||||
|
||||
|
@ -1338,6 +1338,64 @@ std::pair<array, int> LogicalNot::vmap(
|
||||
return {logical_not(inputs[0], stream()), axes[0]};
|
||||
}
|
||||
|
||||
std::vector<array> LogicalAnd::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 2);
|
||||
|
||||
return {zeros_like(cotan, stream()), zeros_like(cotan, stream())};
|
||||
}
|
||||
|
||||
array LogicalAnd::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 2);
|
||||
assert(argnums.size() <= 2);
|
||||
|
||||
return zeros_like(primals[0], stream());
|
||||
}
|
||||
|
||||
std::pair<array, int> LogicalAnd::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
assert(inputs.size() == 2);
|
||||
assert(axes.size() == 2);
|
||||
|
||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||
return {logical_and(a, b, stream()), to_ax};
|
||||
}
|
||||
|
||||
std::vector<array> LogicalOr::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 2);
|
||||
|
||||
return {zeros_like(cotan, stream()), zeros_like(cotan, stream())};
|
||||
}
|
||||
|
||||
array LogicalOr::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 2);
|
||||
assert(argnums.size() <= 2);
|
||||
|
||||
return zeros_like(primals[0], stream());
|
||||
}
|
||||
|
||||
std::pair<array, int> LogicalOr::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
assert(inputs.size() == 2);
|
||||
assert(axes.size() == 2);
|
||||
|
||||
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||
return {logical_or(a, b, stream()), to_ax};
|
||||
}
|
||||
|
||||
std::vector<array> LogAddExp::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
|
@ -903,6 +903,44 @@ class LogicalNot : public Primitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class LogicalAnd : public Primitive {
|
||||
public:
|
||||
explicit LogicalAnd(Stream stream) : Primitive(stream){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
std::pair<array, int> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(LogicalAnd)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class LogicalOr : public Primitive {
|
||||
public:
|
||||
explicit LogicalOr(Stream stream) : Primitive(stream){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
std::pair<array, int> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(LogicalOr)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class LogAddExp : public Primitive {
|
||||
public:
|
||||
explicit LogAddExp(Stream stream) : Primitive(stream){};
|
||||
|
@ -764,6 +764,18 @@ void init_array(py::module_& m) {
|
||||
return power(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__and__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return logical_and(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__or__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
return logical_or(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"flatten",
|
||||
[](const array& a,
|
||||
|
@ -670,6 +670,51 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: The boolean array containing the logical not of ``a``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"logical_and",
|
||||
[](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) {
|
||||
return logical_and(to_array(a), to_array(b), s);
|
||||
},
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
logical_and(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise logical and.
|
||||
|
||||
Args:
|
||||
a (array): First input array or scalar.
|
||||
b (array): Second input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The boolean array containing the logical and of ``a`` and ``b``.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"logical_or",
|
||||
[](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) {
|
||||
return logical_or(to_array(a), to_array(b), s);
|
||||
},
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
logical_or(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise logical or.
|
||||
|
||||
Args:
|
||||
a (array): First input array or scalar.
|
||||
b (array): Second input array or scalar.
|
||||
|
||||
Returns:
|
||||
array: The boolean array containing the logical or of ``a`` and ``b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"logaddexp",
|
||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||
|
@ -610,6 +610,28 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
expected = np.logical_not(a)
|
||||
self.assertTrue(np.array_equal(result, expected))
|
||||
|
||||
def test_logical_and(self):
|
||||
a = mx.array([True, False, True, False])
|
||||
b = mx.array([True, True, False, False])
|
||||
result = mx.logical_and(a, b)
|
||||
expected = np.logical_and(a, b)
|
||||
self.assertTrue(np.array_equal(result, expected))
|
||||
|
||||
# test overloaded operator
|
||||
result = a & b
|
||||
self.assertTrue(np.array_equal(result, expected))
|
||||
|
||||
def test_logical_or(self):
|
||||
a = mx.array([True, False, True, False])
|
||||
b = mx.array([True, True, False, False])
|
||||
result = mx.logical_or(a, b)
|
||||
expected = np.logical_or(a, b)
|
||||
self.assertTrue(np.array_equal(result, expected))
|
||||
|
||||
# test overloaded operator
|
||||
result = a | b
|
||||
self.assertTrue(np.array_equal(result, expected))
|
||||
|
||||
def test_square(self):
|
||||
a = mx.array([0.1, 0.5, 1.0, 10.0])
|
||||
result = mx.square(a)
|
||||
|
@ -80,6 +80,8 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
"multiply",
|
||||
"power",
|
||||
"subtract",
|
||||
"logical_or",
|
||||
"logical_and",
|
||||
]
|
||||
for opname in ops:
|
||||
with self.subTest(op=opname):
|
||||
|
@ -795,6 +795,44 @@ TEST_CASE("test arithmetic unary ops") {
|
||||
CHECK_EQ(y.item<bool>(), true);
|
||||
}
|
||||
|
||||
// Test logical and
|
||||
{
|
||||
array x(true);
|
||||
array y(true);
|
||||
CHECK_EQ(logical_and(x, y).item<bool>(), true);
|
||||
|
||||
x = array(1.0f);
|
||||
y = array(1.0f);
|
||||
auto z = logical_and(x, y);
|
||||
CHECK_EQ(z.dtype(), bool_);
|
||||
CHECK_EQ(z.item<bool>(), true);
|
||||
|
||||
x = array(0);
|
||||
y = array(1.0f);
|
||||
z = logical_and(x, y);
|
||||
CHECK_EQ(z.dtype(), bool_);
|
||||
CHECK_EQ(z.item<bool>(), false);
|
||||
}
|
||||
|
||||
// Test logical or
|
||||
{
|
||||
array x(false);
|
||||
array y(false);
|
||||
CHECK_EQ(logical_or(x, y).item<bool>(), false);
|
||||
|
||||
x = array(1.0f);
|
||||
y = array(1.0f);
|
||||
auto z = logical_or(x, y);
|
||||
CHECK_EQ(z.dtype(), bool_);
|
||||
CHECK_EQ(z.item<bool>(), true);
|
||||
|
||||
x = array(0);
|
||||
y = array(1.0f);
|
||||
z = logical_or(x, y);
|
||||
CHECK_EQ(z.dtype(), bool_);
|
||||
CHECK_EQ(z.item<bool>(), true);
|
||||
}
|
||||
|
||||
// Test abs
|
||||
{
|
||||
array x({-1.0f, 0.0f, 1.0f});
|
||||
|
Loading…
Reference in New Issue
Block a user