mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Comparing python objects (such as list/tuple) with mlx.core.array
(#920)
* add implicit conversion of list to array for equality constraint * add tests for array equality * add test for tuple and array equality * return False if __eq__ arg is list or tuple * write tests for equality * update the rule of comparison for __ge__/__gt__/__lt__/__le__ * add a helper function for detecting mlx.core.array * return true in case fo inequality * debug minor issue regarding detecting mlx array * add tests for inequality comparisons * add name for contribution * reformat files using pre-commit * update tests for float * update tests for inequality * raise exception in case of invalid comparisons * use isinstance instead of string comparison * replace "is_convirtable_to_array" with previous logic * remove throwing exceptions for other operations * just a comment * minor changes for efficiency * optimize a utils function * change the function name * Update ACKNOWLEDGMENTS.md --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:

committed by
GitHub

parent
913b19329c
commit
f48bc496c7
@@ -820,37 +820,45 @@ void init_array(nb::module_& m) {
|
||||
"other"_a)
|
||||
.def(
|
||||
"__eq__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
[](const array& a,
|
||||
const ScalarOrArray& v) -> std::variant<array, bool> {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
return false;
|
||||
}
|
||||
return equal(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__lt__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
[](const array& a, const ScalarOrArray v) -> array {
|
||||
return less(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__le__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
[](const array& a, const ScalarOrArray v) -> array {
|
||||
return less_equal(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__gt__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
[](const array& a, const ScalarOrArray v) -> array {
|
||||
return greater(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ge__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
[](const array& a, const ScalarOrArray v) -> array {
|
||||
return greater_equal(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
.def(
|
||||
"__ne__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
[](const array& a,
|
||||
const ScalarOrArray v) -> std::variant<array, bool> {
|
||||
if (!is_comparable_with_array(v)) {
|
||||
return true;
|
||||
}
|
||||
return not_equal(a, to_array(v, a.dtype()));
|
||||
},
|
||||
"other"_a)
|
||||
@@ -1432,4 +1440,4 @@ void init_array(nb::module_& m) {
|
||||
R"pbdoc(
|
||||
Extract a diagonal or construct a diagonal matrix.
|
||||
)pbdoc");
|
||||
}
|
||||
}
|
@@ -44,6 +44,18 @@ inline array to_array_with_accessor(nb::object obj) {
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_comparable_with_array(const ScalarOrArray& v) {
|
||||
// Checks if the value can be compared to an array (or is already an
|
||||
// mlx array)
|
||||
if (auto pv = std::get_if<nb::object>(&v); pv) {
|
||||
return nb::isinstance<array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
|
||||
} else {
|
||||
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
|
||||
// and can be compared to an array
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
inline array to_array(
|
||||
const ScalarOrArray& v,
|
||||
std::optional<Dtype> dtype = std::nullopt) {
|
||||
|
@@ -89,6 +89,121 @@ class TestDtypes(mlx_tests.MLXTestCase):
|
||||
self.assertListEqual(list(z.shape), list(y.shape))
|
||||
|
||||
|
||||
class TestEquality(mlx_tests.MLXTestCase):
|
||||
def test_array_eq_array(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([1, 2, 3])
|
||||
c = mx.array([1, 2, 4])
|
||||
self.assertTrue(mx.all(a == b))
|
||||
self.assertFalse(mx.all(a == c))
|
||||
|
||||
def test_array_eq_scalar(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = 1
|
||||
c = 4
|
||||
d = 2.5
|
||||
e = mx.array([1, 2.5, 3.25])
|
||||
self.assertTrue(mx.any(a == b))
|
||||
self.assertFalse(mx.all(a == c))
|
||||
self.assertFalse(mx.all(a == d))
|
||||
self.assertTrue(mx.any(a == e))
|
||||
|
||||
def test_list_equals_array(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = [1, 2, 3]
|
||||
c = [1, 2, 4]
|
||||
|
||||
# mlx array equality returns false if is compared with any kind of
|
||||
# object which is not an mlx array
|
||||
self.assertFalse(a == b)
|
||||
self.assertFalse(a == c)
|
||||
|
||||
def test_tuple_equals_array(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = (1, 2, 3)
|
||||
c = (1, 2, 4)
|
||||
|
||||
# mlx array equality returns false if is compared with any kind of
|
||||
# object which is not an mlx array
|
||||
self.assertFalse(a == b)
|
||||
self.assertFalse(a == c)
|
||||
|
||||
|
||||
class TestInequality(mlx_tests.MLXTestCase):
|
||||
def test_array_ne_array(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([1, 2, 3])
|
||||
c = mx.array([1, 2, 4])
|
||||
self.assertFalse(mx.any(a != b))
|
||||
self.assertTrue(mx.any(a != c))
|
||||
|
||||
def test_array_ne_scalar(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = 1
|
||||
c = 4
|
||||
d = 1.5
|
||||
e = 2.5
|
||||
f = mx.array([1, 2.5, 3.25])
|
||||
self.assertFalse(mx.all(a != b))
|
||||
self.assertTrue(mx.any(a != c))
|
||||
self.assertTrue(mx.any(a != d))
|
||||
self.assertTrue(mx.any(a != e))
|
||||
self.assertFalse(mx.all(a != f))
|
||||
|
||||
def test_list_not_equals_array(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = [1, 2, 3]
|
||||
c = [1, 2, 4]
|
||||
|
||||
# mlx array inequality returns true if is compared with any kind of
|
||||
# object which is not an mlx array
|
||||
self.assertTrue(a != b)
|
||||
self.assertTrue(a != c)
|
||||
|
||||
def test_tuple_not_equals_array(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = (1, 2, 3)
|
||||
c = (1, 2, 4)
|
||||
|
||||
# mlx array inequality returns true if is compared with any kind of
|
||||
# object which is not an mlx array
|
||||
self.assertTrue(a != b)
|
||||
self.assertTrue(a != c)
|
||||
|
||||
def test_obj_inequality_array(self):
|
||||
str_ = "hello"
|
||||
a = mx.array([1, 2, 3])
|
||||
lst_ = [1, 2, 3]
|
||||
tpl_ = (1, 2, 3)
|
||||
|
||||
# check if object comparison(</>/<=/>=) with mlx array should throw an exception
|
||||
# if not, the tests will fail
|
||||
with self.assertRaises(ValueError):
|
||||
a < str_
|
||||
with self.assertRaises(ValueError):
|
||||
a > str_
|
||||
with self.assertRaises(ValueError):
|
||||
a <= str_
|
||||
with self.assertRaises(ValueError):
|
||||
a >= str_
|
||||
with self.assertRaises(ValueError):
|
||||
a < lst_
|
||||
with self.assertRaises(ValueError):
|
||||
a > lst_
|
||||
with self.assertRaises(ValueError):
|
||||
a <= lst_
|
||||
with self.assertRaises(ValueError):
|
||||
a >= lst_
|
||||
with self.assertRaises(ValueError):
|
||||
a < tpl_
|
||||
with self.assertRaises(ValueError):
|
||||
a > tpl_
|
||||
with self.assertRaises(ValueError):
|
||||
a <= tpl_
|
||||
with self.assertRaises(ValueError):
|
||||
a >= tpl_
|
||||
|
||||
|
||||
class TestArray(mlx_tests.MLXTestCase):
|
||||
def test_array_basics(self):
|
||||
x = mx.array(1)
|
||||
|
Reference in New Issue
Block a user