Comparing python objects (such as list/tuple) with mlx.core.array (#920)

* add implicit conversion of list to array for equality constraint

* add tests for array equality

* add test for tuple and array equality

* return False if __eq__ arg is list or tuple

* write tests for equality

* update the rule of comparison for __ge__/__gt__/__lt__/__le__

* add a helper function for detecting mlx.core.array

* return true in case fo inequality

* debug minor issue regarding detecting mlx array

* add tests for inequality comparisons

* add name for contribution

* reformat files using pre-commit

* update tests for float

* update tests for inequality

* raise exception in case of invalid comparisons

* use isinstance instead of string comparison

* replace "is_convirtable_to_array" with previous logic

* remove throwing exceptions for other operations

* just a comment

* minor changes for efficiency

* optimize a utils function

* change the function name

* Update ACKNOWLEDGMENTS.md

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
AmirHossein_Razlighi 2024-03-29 17:22:30 +03:30 committed by GitHub
parent 913b19329c
commit f48bc496c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 143 additions and 8 deletions

View File

@ -15,7 +15,7 @@ MLX was developed with contributions from the following individuals:
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`.
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>

View File

@ -820,37 +820,45 @@ void init_array(nb::module_& m) {
"other"_a)
.def(
"__eq__",
[](const array& a, const ScalarOrArray v) {
[](const array& a,
const ScalarOrArray& v) -> std::variant<array, bool> {
if (!is_comparable_with_array(v)) {
return false;
}
return equal(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__lt__",
[](const array& a, const ScalarOrArray v) {
[](const array& a, const ScalarOrArray v) -> array {
return less(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__le__",
[](const array& a, const ScalarOrArray v) {
[](const array& a, const ScalarOrArray v) -> array {
return less_equal(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__gt__",
[](const array& a, const ScalarOrArray v) {
[](const array& a, const ScalarOrArray v) -> array {
return greater(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__ge__",
[](const array& a, const ScalarOrArray v) {
[](const array& a, const ScalarOrArray v) -> array {
return greater_equal(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__ne__",
[](const array& a, const ScalarOrArray v) {
[](const array& a,
const ScalarOrArray v) -> std::variant<array, bool> {
if (!is_comparable_with_array(v)) {
return true;
}
return not_equal(a, to_array(v, a.dtype()));
},
"other"_a)
@ -1432,4 +1440,4 @@ void init_array(nb::module_& m) {
R"pbdoc(
Extract a diagonal or construct a diagonal matrix.
)pbdoc");
}
}

View File

@ -44,6 +44,18 @@ inline array to_array_with_accessor(nb::object obj) {
}
}
inline bool is_comparable_with_array(const ScalarOrArray& v) {
// Checks if the value can be compared to an array (or is already an
// mlx array)
if (auto pv = std::get_if<nb::object>(&v); pv) {
return nb::isinstance<array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
} else {
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
// and can be compared to an array
return true;
}
}
inline array to_array(
const ScalarOrArray& v,
std::optional<Dtype> dtype = std::nullopt) {

View File

@ -89,6 +89,121 @@ class TestDtypes(mlx_tests.MLXTestCase):
self.assertListEqual(list(z.shape), list(y.shape))
class TestEquality(mlx_tests.MLXTestCase):
def test_array_eq_array(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 2, 3])
c = mx.array([1, 2, 4])
self.assertTrue(mx.all(a == b))
self.assertFalse(mx.all(a == c))
def test_array_eq_scalar(self):
a = mx.array([1, 2, 3])
b = 1
c = 4
d = 2.5
e = mx.array([1, 2.5, 3.25])
self.assertTrue(mx.any(a == b))
self.assertFalse(mx.all(a == c))
self.assertFalse(mx.all(a == d))
self.assertTrue(mx.any(a == e))
def test_list_equals_array(self):
a = mx.array([1, 2, 3])
b = [1, 2, 3]
c = [1, 2, 4]
# mlx array equality returns false if is compared with any kind of
# object which is not an mlx array
self.assertFalse(a == b)
self.assertFalse(a == c)
def test_tuple_equals_array(self):
a = mx.array([1, 2, 3])
b = (1, 2, 3)
c = (1, 2, 4)
# mlx array equality returns false if is compared with any kind of
# object which is not an mlx array
self.assertFalse(a == b)
self.assertFalse(a == c)
class TestInequality(mlx_tests.MLXTestCase):
def test_array_ne_array(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 2, 3])
c = mx.array([1, 2, 4])
self.assertFalse(mx.any(a != b))
self.assertTrue(mx.any(a != c))
def test_array_ne_scalar(self):
a = mx.array([1, 2, 3])
b = 1
c = 4
d = 1.5
e = 2.5
f = mx.array([1, 2.5, 3.25])
self.assertFalse(mx.all(a != b))
self.assertTrue(mx.any(a != c))
self.assertTrue(mx.any(a != d))
self.assertTrue(mx.any(a != e))
self.assertFalse(mx.all(a != f))
def test_list_not_equals_array(self):
a = mx.array([1, 2, 3])
b = [1, 2, 3]
c = [1, 2, 4]
# mlx array inequality returns true if is compared with any kind of
# object which is not an mlx array
self.assertTrue(a != b)
self.assertTrue(a != c)
def test_tuple_not_equals_array(self):
a = mx.array([1, 2, 3])
b = (1, 2, 3)
c = (1, 2, 4)
# mlx array inequality returns true if is compared with any kind of
# object which is not an mlx array
self.assertTrue(a != b)
self.assertTrue(a != c)
def test_obj_inequality_array(self):
str_ = "hello"
a = mx.array([1, 2, 3])
lst_ = [1, 2, 3]
tpl_ = (1, 2, 3)
# check if object comparison(</>/<=/>=) with mlx array should throw an exception
# if not, the tests will fail
with self.assertRaises(ValueError):
a < str_
with self.assertRaises(ValueError):
a > str_
with self.assertRaises(ValueError):
a <= str_
with self.assertRaises(ValueError):
a >= str_
with self.assertRaises(ValueError):
a < lst_
with self.assertRaises(ValueError):
a > lst_
with self.assertRaises(ValueError):
a <= lst_
with self.assertRaises(ValueError):
a >= lst_
with self.assertRaises(ValueError):
a < tpl_
with self.assertRaises(ValueError):
a > tpl_
with self.assertRaises(ValueError):
a <= tpl_
with self.assertRaises(ValueError):
a >= tpl_
class TestArray(mlx_tests.MLXTestCase):
def test_array_basics(self):
x = mx.array(1)