Comparing python objects (such as list/tuple) with mlx.core.array (#920)

* add implicit conversion of list to array for equality constraint

* add tests for array equality

* add test for tuple and array equality

* return False if __eq__ arg is list or tuple

* write tests for equality

* update the rule of comparison for __ge__/__gt__/__lt__/__le__

* add a helper function for detecting mlx.core.array

* return true in case fo inequality

* debug minor issue regarding detecting mlx array

* add tests for inequality comparisons

* add name for contribution

* reformat files using pre-commit

* update tests for float

* update tests for inequality

* raise exception in case of invalid comparisons

* use isinstance instead of string comparison

* replace "is_convirtable_to_array" with previous logic

* remove throwing exceptions for other operations

* just a comment

* minor changes for efficiency

* optimize a utils function

* change the function name

* Update ACKNOWLEDGMENTS.md

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
AmirHossein_Razlighi
2024-03-29 17:22:30 +03:30
committed by GitHub
parent 913b19329c
commit f48bc496c7
4 changed files with 143 additions and 8 deletions

View File

@@ -44,6 +44,18 @@ inline array to_array_with_accessor(nb::object obj) {
}
}
inline bool is_comparable_with_array(const ScalarOrArray& v) {
// Checks if the value can be compared to an array (or is already an
// mlx array)
if (auto pv = std::get_if<nb::object>(&v); pv) {
return nb::isinstance<array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
} else {
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
// and can be compared to an array
return true;
}
}
inline array to_array(
const ScalarOrArray& v,
std::optional<Dtype> dtype = std::nullopt) {