Comparing python objects (such as list/tuple) with mlx.core.array (#920)

* add implicit conversion of list to array for equality constraint

* add tests for array equality

* add test for tuple and array equality

* return False if __eq__ arg is list or tuple

* write tests for equality

* update the rule of comparison for __ge__/__gt__/__lt__/__le__

* add a helper function for detecting mlx.core.array

* return true in case fo inequality

* debug minor issue regarding detecting mlx array

* add tests for inequality comparisons

* add name for contribution

* reformat files using pre-commit

* update tests for float

* update tests for inequality

* raise exception in case of invalid comparisons

* use isinstance instead of string comparison

* replace "is_convirtable_to_array" with previous logic

* remove throwing exceptions for other operations

* just a comment

* minor changes for efficiency

* optimize a utils function

* change the function name

* Update ACKNOWLEDGMENTS.md

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
AmirHossein_Razlighi
2024-03-29 17:22:30 +03:30
committed by GitHub
parent 913b19329c
commit f48bc496c7
4 changed files with 143 additions and 8 deletions

View File

@@ -820,37 +820,45 @@ void init_array(nb::module_& m) {
"other"_a)
.def(
"__eq__",
[](const array& a, const ScalarOrArray v) {
[](const array& a,
const ScalarOrArray& v) -> std::variant<array, bool> {
if (!is_comparable_with_array(v)) {
return false;
}
return equal(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__lt__",
[](const array& a, const ScalarOrArray v) {
[](const array& a, const ScalarOrArray v) -> array {
return less(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__le__",
[](const array& a, const ScalarOrArray v) {
[](const array& a, const ScalarOrArray v) -> array {
return less_equal(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__gt__",
[](const array& a, const ScalarOrArray v) {
[](const array& a, const ScalarOrArray v) -> array {
return greater(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__ge__",
[](const array& a, const ScalarOrArray v) {
[](const array& a, const ScalarOrArray v) -> array {
return greater_equal(a, to_array(v, a.dtype()));
},
"other"_a)
.def(
"__ne__",
[](const array& a, const ScalarOrArray v) {
[](const array& a,
const ScalarOrArray v) -> std::variant<array, bool> {
if (!is_comparable_with_array(v)) {
return true;
}
return not_equal(a, to_array(v, a.dtype()));
},
"other"_a)
@@ -1432,4 +1440,4 @@ void init_array(nb::module_& m) {
R"pbdoc(
Extract a diagonal or construct a diagonal matrix.
)pbdoc");
}
}

View File

@@ -44,6 +44,18 @@ inline array to_array_with_accessor(nb::object obj) {
}
}
inline bool is_comparable_with_array(const ScalarOrArray& v) {
// Checks if the value can be compared to an array (or is already an
// mlx array)
if (auto pv = std::get_if<nb::object>(&v); pv) {
return nb::isinstance<array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
} else {
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
// and can be compared to an array
return true;
}
}
inline array to_array(
const ScalarOrArray& v,
std::optional<Dtype> dtype = std::nullopt) {