From f48bc496c7a164c8947845249a227d01496b6554 Mon Sep 17 00:00:00 2001 From: AmirHossein_Razlighi <79264971+amirhossein-razlighi@users.noreply.github.com> Date: Fri, 29 Mar 2024 17:22:30 +0330 Subject: [PATCH] Comparing python objects (such as list/tuple) with `mlx.core.array` (#920) * add implicit conversion of list to array for equality constraint * add tests for array equality * add test for tuple and array equality * return False if __eq__ arg is list or tuple * write tests for equality * update the rule of comparison for __ge__/__gt__/__lt__/__le__ * add a helper function for detecting mlx.core.array * return true in case fo inequality * debug minor issue regarding detecting mlx array * add tests for inequality comparisons * add name for contribution * reformat files using pre-commit * update tests for float * update tests for inequality * raise exception in case of invalid comparisons * use isinstance instead of string comparison * replace "is_convirtable_to_array" with previous logic * remove throwing exceptions for other operations * just a comment * minor changes for efficiency * optimize a utils function * change the function name * Update ACKNOWLEDGMENTS.md --------- Co-authored-by: Awni Hannun --- ACKNOWLEDGMENTS.md | 2 +- python/src/array.cpp | 22 ++++--- python/src/utils.h | 12 ++++ python/tests/test_array.py | 115 +++++++++++++++++++++++++++++++++++++ 4 files changed, 143 insertions(+), 8 deletions(-) diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 0e5d1142d..784dc6329 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -15,7 +15,7 @@ MLX was developed with contributions from the following individuals: - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops. - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays. - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention` -- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. +- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. diff --git a/python/src/array.cpp b/python/src/array.cpp index de57b701e..b405faf84 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -820,37 +820,45 @@ void init_array(nb::module_& m) { "other"_a) .def( "__eq__", - [](const array& a, const ScalarOrArray v) { + [](const array& a, + const ScalarOrArray& v) -> std::variant { + if (!is_comparable_with_array(v)) { + return false; + } return equal(a, to_array(v, a.dtype())); }, "other"_a) .def( "__lt__", - [](const array& a, const ScalarOrArray v) { + [](const array& a, const ScalarOrArray v) -> array { return less(a, to_array(v, a.dtype())); }, "other"_a) .def( "__le__", - [](const array& a, const ScalarOrArray v) { + [](const array& a, const ScalarOrArray v) -> array { return less_equal(a, to_array(v, a.dtype())); }, "other"_a) .def( "__gt__", - [](const array& a, const ScalarOrArray v) { + [](const array& a, const ScalarOrArray v) -> array { return greater(a, to_array(v, a.dtype())); }, "other"_a) .def( "__ge__", - [](const array& a, const ScalarOrArray v) { + [](const array& a, const ScalarOrArray v) -> array { return greater_equal(a, to_array(v, a.dtype())); }, "other"_a) .def( "__ne__", - [](const array& a, const ScalarOrArray v) { + [](const array& a, + const ScalarOrArray v) -> std::variant { + if (!is_comparable_with_array(v)) { + return true; + } return not_equal(a, to_array(v, a.dtype())); }, "other"_a) @@ -1432,4 +1440,4 @@ void init_array(nb::module_& m) { R"pbdoc( Extract a diagonal or construct a diagonal matrix. )pbdoc"); -} +} \ No newline at end of file diff --git a/python/src/utils.h b/python/src/utils.h index a320412bd..41b422b98 100644 --- a/python/src/utils.h +++ b/python/src/utils.h @@ -44,6 +44,18 @@ inline array to_array_with_accessor(nb::object obj) { } } +inline bool is_comparable_with_array(const ScalarOrArray& v) { + // Checks if the value can be compared to an array (or is already an + // mlx array) + if (auto pv = std::get_if(&v); pv) { + return nb::isinstance(*pv) || nb::hasattr(*pv, "__mlx_array__"); + } else { + // If it's not an object, it's a scalar (nb::int_, nb::float_, etc.) + // and can be compared to an array + return true; + } +} + inline array to_array( const ScalarOrArray& v, std::optional dtype = std::nullopt) { diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 6094e025f..4e8c00134 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -89,6 +89,121 @@ class TestDtypes(mlx_tests.MLXTestCase): self.assertListEqual(list(z.shape), list(y.shape)) +class TestEquality(mlx_tests.MLXTestCase): + def test_array_eq_array(self): + a = mx.array([1, 2, 3]) + b = mx.array([1, 2, 3]) + c = mx.array([1, 2, 4]) + self.assertTrue(mx.all(a == b)) + self.assertFalse(mx.all(a == c)) + + def test_array_eq_scalar(self): + a = mx.array([1, 2, 3]) + b = 1 + c = 4 + d = 2.5 + e = mx.array([1, 2.5, 3.25]) + self.assertTrue(mx.any(a == b)) + self.assertFalse(mx.all(a == c)) + self.assertFalse(mx.all(a == d)) + self.assertTrue(mx.any(a == e)) + + def test_list_equals_array(self): + a = mx.array([1, 2, 3]) + b = [1, 2, 3] + c = [1, 2, 4] + + # mlx array equality returns false if is compared with any kind of + # object which is not an mlx array + self.assertFalse(a == b) + self.assertFalse(a == c) + + def test_tuple_equals_array(self): + a = mx.array([1, 2, 3]) + b = (1, 2, 3) + c = (1, 2, 4) + + # mlx array equality returns false if is compared with any kind of + # object which is not an mlx array + self.assertFalse(a == b) + self.assertFalse(a == c) + + +class TestInequality(mlx_tests.MLXTestCase): + def test_array_ne_array(self): + a = mx.array([1, 2, 3]) + b = mx.array([1, 2, 3]) + c = mx.array([1, 2, 4]) + self.assertFalse(mx.any(a != b)) + self.assertTrue(mx.any(a != c)) + + def test_array_ne_scalar(self): + a = mx.array([1, 2, 3]) + b = 1 + c = 4 + d = 1.5 + e = 2.5 + f = mx.array([1, 2.5, 3.25]) + self.assertFalse(mx.all(a != b)) + self.assertTrue(mx.any(a != c)) + self.assertTrue(mx.any(a != d)) + self.assertTrue(mx.any(a != e)) + self.assertFalse(mx.all(a != f)) + + def test_list_not_equals_array(self): + a = mx.array([1, 2, 3]) + b = [1, 2, 3] + c = [1, 2, 4] + + # mlx array inequality returns true if is compared with any kind of + # object which is not an mlx array + self.assertTrue(a != b) + self.assertTrue(a != c) + + def test_tuple_not_equals_array(self): + a = mx.array([1, 2, 3]) + b = (1, 2, 3) + c = (1, 2, 4) + + # mlx array inequality returns true if is compared with any kind of + # object which is not an mlx array + self.assertTrue(a != b) + self.assertTrue(a != c) + + def test_obj_inequality_array(self): + str_ = "hello" + a = mx.array([1, 2, 3]) + lst_ = [1, 2, 3] + tpl_ = (1, 2, 3) + + # check if object comparison(/<=/>=) with mlx array should throw an exception + # if not, the tests will fail + with self.assertRaises(ValueError): + a < str_ + with self.assertRaises(ValueError): + a > str_ + with self.assertRaises(ValueError): + a <= str_ + with self.assertRaises(ValueError): + a >= str_ + with self.assertRaises(ValueError): + a < lst_ + with self.assertRaises(ValueError): + a > lst_ + with self.assertRaises(ValueError): + a <= lst_ + with self.assertRaises(ValueError): + a >= lst_ + with self.assertRaises(ValueError): + a < tpl_ + with self.assertRaises(ValueError): + a > tpl_ + with self.assertRaises(ValueError): + a <= tpl_ + with self.assertRaises(ValueError): + a >= tpl_ + + class TestArray(mlx_tests.MLXTestCase): def test_array_basics(self): x = mx.array(1)