mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-11 14:34:37 +08:00
[Fix] mx.allclose bug with infinite values (#539)
* Added isclose op and fixed comparison with inf values * Added 'equal_nan' to match numpy * format * Add test * Update python/src/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Addressed CR comments * Update python/src/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * nits --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -566,7 +566,7 @@ void init_ops(py::module_& m) {
|
||||
Args:
|
||||
a (array): Input array or scalar.
|
||||
b (array): Input array or scalar.
|
||||
equal_nan (bool): If ``True``, NaNs are treated as equal.
|
||||
equal_nan (bool): If ``True``, NaNs are considered equal.
|
||||
Defaults to ``False``.
|
||||
|
||||
Returns:
|
||||
@@ -1648,12 +1648,15 @@ void init_ops(py::module_& m) {
|
||||
"rtol"_a = 1e-5,
|
||||
"atol"_a = 1e-8,
|
||||
py::kw_only(),
|
||||
"equal_nan"_a = false,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Approximate comparison of two arrays.
|
||||
|
||||
Infinite values are considered equal if they have the same sign, NaN values are not equal unless ``equal_nan`` is ``True``.
|
||||
|
||||
The arrays are considered equal if:
|
||||
|
||||
.. code-block::
|
||||
@@ -1668,6 +1671,47 @@ void init_ops(py::module_& m) {
|
||||
b (array): Input array.
|
||||
rtol (float): Relative tolerance.
|
||||
atol (float): Absolute tolerance.
|
||||
equal_nan (bool): If ``True``, NaNs are considered equal.
|
||||
Defaults to ``False``.
|
||||
|
||||
Returns:
|
||||
array: The boolean output scalar indicating if the arrays are close.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"isclose",
|
||||
&isclose,
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
py::pos_only(),
|
||||
"rtol"_a = 1e-5,
|
||||
"atol"_a = 1e-8,
|
||||
py::kw_only(),
|
||||
"equal_nan"_a = false,
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
isclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Returns a boolean array where two arrays are element-wise equal within a tolerance.
|
||||
|
||||
Infinite values are considered equal if they have the same sign, NaN values are
|
||||
not equal unless ``equal_nan`` is ``True``.
|
||||
|
||||
Two values are considered equal if:
|
||||
|
||||
.. code-block::
|
||||
|
||||
abs(a - b) <= (atol + rtol * abs(b))
|
||||
|
||||
Note unlike :func:`array_equal`, this function supports numpy-style
|
||||
broadcasting.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
b (array): Input array.
|
||||
rtol (float): Relative tolerance.
|
||||
atol (float): Absolute tolerance.
|
||||
equal_nan (bool): If ``True``, NaNs are considered equal.
|
||||
Defaults to ``False``.
|
||||
|
||||
Returns:
|
||||
array: The boolean output scalar indicating if the arrays are close.
|
||||
|
Reference in New Issue
Block a user