[Fix] mx.allclose bug with infinite values (#539)

* Added isclose op and fixed comparison with inf values

* Added 'equal_nan' to match numpy

* format

* Add test

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Addressed CR comments

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Rifur13
2024-01-25 23:47:06 -05:00
committed by GitHub
parent 87b7fa9ba2
commit 2463496471
5 changed files with 143 additions and 11 deletions

View File

@@ -566,7 +566,7 @@ void init_ops(py::module_& m) {
Args:
a (array): Input array or scalar.
b (array): Input array or scalar.
equal_nan (bool): If ``True``, NaNs are treated as equal.
equal_nan (bool): If ``True``, NaNs are considered equal.
Defaults to ``False``.
Returns:
@@ -1648,12 +1648,15 @@ void init_ops(py::module_& m) {
"rtol"_a = 1e-5,
"atol"_a = 1e-8,
py::kw_only(),
"equal_nan"_a = false,
"stream"_a = none,
R"pbdoc(
allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, stream: Union[None, Stream, Device] = None) -> array
allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array
Approximate comparison of two arrays.
Infinite values are considered equal if they have the same sign, NaN values are not equal unless ``equal_nan`` is ``True``.
The arrays are considered equal if:
.. code-block::
@@ -1668,6 +1671,47 @@ void init_ops(py::module_& m) {
b (array): Input array.
rtol (float): Relative tolerance.
atol (float): Absolute tolerance.
equal_nan (bool): If ``True``, NaNs are considered equal.
Defaults to ``False``.
Returns:
array: The boolean output scalar indicating if the arrays are close.
)pbdoc");
m.def(
"isclose",
&isclose,
"a"_a,
"b"_a,
py::pos_only(),
"rtol"_a = 1e-5,
"atol"_a = 1e-8,
py::kw_only(),
"equal_nan"_a = false,
"stream"_a = none,
R"pbdoc(
isclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array
Returns a boolean array where two arrays are element-wise equal within a tolerance.
Infinite values are considered equal if they have the same sign, NaN values are
not equal unless ``equal_nan`` is ``True``.
Two values are considered equal if:
.. code-block::
abs(a - b) <= (atol + rtol * abs(b))
Note unlike :func:`array_equal`, this function supports numpy-style
broadcasting.
Args:
a (array): Input array.
b (array): Input array.
rtol (float): Relative tolerance.
atol (float): Absolute tolerance.
equal_nan (bool): If ``True``, NaNs are considered equal.
Defaults to ``False``.
Returns:
array: The boolean output scalar indicating if the arrays are close.