diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 0e5d1142d..784dc6329 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -15,7 +15,7 @@ MLX was developed with contributions from the following individuals:
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
-- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`.
+- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`.
diff --git a/python/src/array.cpp b/python/src/array.cpp
index de57b701e..b405faf84 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -820,37 +820,45 @@ void init_array(nb::module_& m) {
"other"_a)
.def(
"__eq__",
- [](const array& a, const ScalarOrArray v) {
+ [](const array& a,
+ const ScalarOrArray& v) -> std::variant {
+ if (!is_comparable_with_array(v)) {
+ return false;
+ }
return equal(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__lt__",
- [](const array& a, const ScalarOrArray v) {
+ [](const array& a, const ScalarOrArray v) -> array {
return less(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__le__",
- [](const array& a, const ScalarOrArray v) {
+ [](const array& a, const ScalarOrArray v) -> array {
return less_equal(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__gt__",
- [](const array& a, const ScalarOrArray v) {
+ [](const array& a, const ScalarOrArray v) -> array {
return greater(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__ge__",
- [](const array& a, const ScalarOrArray v) {
+ [](const array& a, const ScalarOrArray v) -> array {
return greater_equal(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__ne__",
- [](const array& a, const ScalarOrArray v) {
+ [](const array& a,
+ const ScalarOrArray v) -> std::variant {
+ if (!is_comparable_with_array(v)) {
+ return true;
+ }
return not_equal(a, to_array(v, a.dtype()));
},
"other"_a)
@@ -1432,4 +1440,4 @@ void init_array(nb::module_& m) {
R"pbdoc(
Extract a diagonal or construct a diagonal matrix.
)pbdoc");
-}
+}
\ No newline at end of file
diff --git a/python/src/utils.h b/python/src/utils.h
index a320412bd..41b422b98 100644
--- a/python/src/utils.h
+++ b/python/src/utils.h
@@ -44,6 +44,18 @@ inline array to_array_with_accessor(nb::object obj) {
}
}
+inline bool is_comparable_with_array(const ScalarOrArray& v) {
+ // Checks if the value can be compared to an array (or is already an
+ // mlx array)
+ if (auto pv = std::get_if(&v); pv) {
+ return nb::isinstance(*pv) || nb::hasattr(*pv, "__mlx_array__");
+ } else {
+ // If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
+ // and can be compared to an array
+ return true;
+ }
+}
+
inline array to_array(
const ScalarOrArray& v,
std::optional dtype = std::nullopt) {
diff --git a/python/tests/test_array.py b/python/tests/test_array.py
index 6094e025f..4e8c00134 100644
--- a/python/tests/test_array.py
+++ b/python/tests/test_array.py
@@ -89,6 +89,121 @@ class TestDtypes(mlx_tests.MLXTestCase):
self.assertListEqual(list(z.shape), list(y.shape))
+class TestEquality(mlx_tests.MLXTestCase):
+ def test_array_eq_array(self):
+ a = mx.array([1, 2, 3])
+ b = mx.array([1, 2, 3])
+ c = mx.array([1, 2, 4])
+ self.assertTrue(mx.all(a == b))
+ self.assertFalse(mx.all(a == c))
+
+ def test_array_eq_scalar(self):
+ a = mx.array([1, 2, 3])
+ b = 1
+ c = 4
+ d = 2.5
+ e = mx.array([1, 2.5, 3.25])
+ self.assertTrue(mx.any(a == b))
+ self.assertFalse(mx.all(a == c))
+ self.assertFalse(mx.all(a == d))
+ self.assertTrue(mx.any(a == e))
+
+ def test_list_equals_array(self):
+ a = mx.array([1, 2, 3])
+ b = [1, 2, 3]
+ c = [1, 2, 4]
+
+ # mlx array equality returns false if is compared with any kind of
+ # object which is not an mlx array
+ self.assertFalse(a == b)
+ self.assertFalse(a == c)
+
+ def test_tuple_equals_array(self):
+ a = mx.array([1, 2, 3])
+ b = (1, 2, 3)
+ c = (1, 2, 4)
+
+ # mlx array equality returns false if is compared with any kind of
+ # object which is not an mlx array
+ self.assertFalse(a == b)
+ self.assertFalse(a == c)
+
+
+class TestInequality(mlx_tests.MLXTestCase):
+ def test_array_ne_array(self):
+ a = mx.array([1, 2, 3])
+ b = mx.array([1, 2, 3])
+ c = mx.array([1, 2, 4])
+ self.assertFalse(mx.any(a != b))
+ self.assertTrue(mx.any(a != c))
+
+ def test_array_ne_scalar(self):
+ a = mx.array([1, 2, 3])
+ b = 1
+ c = 4
+ d = 1.5
+ e = 2.5
+ f = mx.array([1, 2.5, 3.25])
+ self.assertFalse(mx.all(a != b))
+ self.assertTrue(mx.any(a != c))
+ self.assertTrue(mx.any(a != d))
+ self.assertTrue(mx.any(a != e))
+ self.assertFalse(mx.all(a != f))
+
+ def test_list_not_equals_array(self):
+ a = mx.array([1, 2, 3])
+ b = [1, 2, 3]
+ c = [1, 2, 4]
+
+ # mlx array inequality returns true if is compared with any kind of
+ # object which is not an mlx array
+ self.assertTrue(a != b)
+ self.assertTrue(a != c)
+
+ def test_tuple_not_equals_array(self):
+ a = mx.array([1, 2, 3])
+ b = (1, 2, 3)
+ c = (1, 2, 4)
+
+ # mlx array inequality returns true if is compared with any kind of
+ # object which is not an mlx array
+ self.assertTrue(a != b)
+ self.assertTrue(a != c)
+
+ def test_obj_inequality_array(self):
+ str_ = "hello"
+ a = mx.array([1, 2, 3])
+ lst_ = [1, 2, 3]
+ tpl_ = (1, 2, 3)
+
+ # check if object comparison(>/<=/>=) with mlx array should throw an exception
+ # if not, the tests will fail
+ with self.assertRaises(ValueError):
+ a < str_
+ with self.assertRaises(ValueError):
+ a > str_
+ with self.assertRaises(ValueError):
+ a <= str_
+ with self.assertRaises(ValueError):
+ a >= str_
+ with self.assertRaises(ValueError):
+ a < lst_
+ with self.assertRaises(ValueError):
+ a > lst_
+ with self.assertRaises(ValueError):
+ a <= lst_
+ with self.assertRaises(ValueError):
+ a >= lst_
+ with self.assertRaises(ValueError):
+ a < tpl_
+ with self.assertRaises(ValueError):
+ a > tpl_
+ with self.assertRaises(ValueError):
+ a <= tpl_
+ with self.assertRaises(ValueError):
+ a >= tpl_
+
+
class TestArray(mlx_tests.MLXTestCase):
def test_array_basics(self):
x = mx.array(1)