diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 13e74d292..65d8d39a2 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1127,20 +1127,17 @@ array isnan(const array& a, StreamOrDevice s /* = {} */) { } array isinf(const array& a, StreamOrDevice s /* = {} */) { + return logical_or(isposinf(a, s), isneginf(a, s), s); +} + +array isposinf(const array& a, StreamOrDevice s /* = {} */) { if (is_integral(a.dtype())) { return full(a.shape(), false, bool_, s); } return equal(a, array(std::numeric_limits::infinity(), a.dtype()), s); } -array isposinf(const array& a, StreamOrDevice s) { - if (is_integral(a.dtype())) { - return full(a.shape(), false, bool_, s); - } - return equal(a, array(std::numeric_limits::infinity(), a.dtype()), s); -} - -array isneginf(const array& a, StreamOrDevice s) { +array isneginf(const array& a, StreamOrDevice s /* = {} */) { if (is_integral(a.dtype())) { return full(a.shape(), false, bool_, s); } @@ -1162,11 +1159,43 @@ array allclose( const array& b, double rtol /* = 1e-5 */, double atol /* = 1e-8 */, + bool equal_nan /* = false */, + StreamOrDevice s /* = {}*/) { + return all(isclose(a, b, rtol, atol, equal_nan, s), s); +} + +array isclose( + const array& a, + const array& b, + double rtol /* = 1e-5 */, + double atol /* = 1e-8 */, + bool equal_nan /* = false */, StreamOrDevice s /* = {}*/) { // |a - b| <= atol + rtol * |b| auto rhs = add(array(atol), multiply(array(rtol), abs(b, s), s), s); auto lhs = abs(subtract(a, b, s), s); - return all(less_equal(lhs, rhs, s), s); + auto out = less_equal(lhs, rhs, s); + + // Correct the result for infinite values. + auto any_inf = logical_or(isinf(a, s), isinf(b, s), s); + auto both_inf = logical_or( + logical_and(isposinf(a, s), isposinf(b, s), s), + logical_and(isneginf(a, s), isneginf(b, s), s), + s); + + // Convert all elements where either value is infinite to False. + out = logical_and(out, logical_not(any_inf, s), s); + + // Convert all the elements where both values are infinite and of the same + // sign to True. + out = logical_or(out, both_inf, s); + + if (equal_nan) { + auto both_nan = logical_and(isnan(a, s), isnan(b, s), s); + out = logical_or(out, both_nan, s); + } + + return out; } array all(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) { diff --git a/mlx/ops.h b/mlx/ops.h index 1f2ace43f..19889165c 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -404,6 +404,17 @@ array allclose( const array& b, double rtol = 1e-5, double atol = 1e-8, + bool equal_nan = false, + StreamOrDevice s = {}); + +/** Returns a boolean array where two arrays are element-wise equal within the + * specified tolerance. */ +array isclose( + const array& a, + const array& b, + double rtol = 1e-5, + double atol = 1e-8, + bool equal_nan = false, StreamOrDevice s = {}); /** diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 7f2ce27ee..1156391f7 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -566,7 +566,7 @@ void init_ops(py::module_& m) { Args: a (array): Input array or scalar. b (array): Input array or scalar. - equal_nan (bool): If ``True``, NaNs are treated as equal. + equal_nan (bool): If ``True``, NaNs are considered equal. Defaults to ``False``. Returns: @@ -1648,12 +1648,15 @@ void init_ops(py::module_& m) { "rtol"_a = 1e-5, "atol"_a = 1e-8, py::kw_only(), + "equal_nan"_a = false, "stream"_a = none, R"pbdoc( - allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, stream: Union[None, Stream, Device] = None) -> array + allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array Approximate comparison of two arrays. + Infinite values are considered equal if they have the same sign, NaN values are not equal unless ``equal_nan`` is ``True``. + The arrays are considered equal if: .. code-block:: @@ -1668,6 +1671,47 @@ void init_ops(py::module_& m) { b (array): Input array. rtol (float): Relative tolerance. atol (float): Absolute tolerance. + equal_nan (bool): If ``True``, NaNs are considered equal. + Defaults to ``False``. + + Returns: + array: The boolean output scalar indicating if the arrays are close. + )pbdoc"); + m.def( + "isclose", + &isclose, + "a"_a, + "b"_a, + py::pos_only(), + "rtol"_a = 1e-5, + "atol"_a = 1e-8, + py::kw_only(), + "equal_nan"_a = false, + "stream"_a = none, + R"pbdoc( + isclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array + + Returns a boolean array where two arrays are element-wise equal within a tolerance. + + Infinite values are considered equal if they have the same sign, NaN values are + not equal unless ``equal_nan`` is ``True``. + + Two values are considered equal if: + + .. code-block:: + + abs(a - b) <= (atol + rtol * abs(b)) + + Note unlike :func:`array_equal`, this function supports numpy-style + broadcasting. + + Args: + a (array): Input array. + b (array): Input array. + rtol (float): Relative tolerance. + atol (float): Absolute tolerance. + equal_nan (bool): If ``True``, NaNs are considered equal. + Defaults to ``False``. Returns: array: The boolean output scalar indicating if the arrays are close. diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 9f3226bdd..bbc6e6a84 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -855,6 +855,21 @@ class TestOps(mlx_tests.MLXTestCase): self.assertFalse(mx.allclose(a, b, 0.01).item()) self.assertTrue(mx.allclose(a, b, 0.01, 0.1).item()) + c = mx.array(float("inf")) + self.assertTrue(mx.allclose(c, c).item()) + + def test_isclose(self): + a = mx.array([float("inf"), float("inf"), float("-inf")]) + b = mx.array([float("inf"), float("-inf"), float("-inf")]) + + self.assertListEqual(mx.isclose(a, b).tolist(), [True, False, True]) + + a = mx.array([np.nan]) + self.assertListEqual(mx.isclose(a, a).tolist(), [False]) + + a = mx.array([np.nan]) + self.assertListEqual(mx.isclose(a, a, equal_nan=True).tolist(), [True]) + def test_all(self): a = mx.array([[True, False], [True, True]]) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index c17e25572..927ad2874 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -514,6 +514,9 @@ TEST_CASE("test is inf") { array y(inf); CHECK(isinf(y).item()); + auto neginf = -std::numeric_limits::infinity(); + CHECK(isinf(array(neginf)).item()); + array z = identity(7); CHECK_FALSE(any(isinf(z)).item()); @@ -545,6 +548,36 @@ TEST_CASE("test all close") { CHECK(allclose(x, y, 0.01, 0.1).item()); } +TEST_CASE("test is close") { + { + array a({1.0, std::numeric_limits::infinity()}); + array b({1.0, std::numeric_limits::infinity()}); + CHECK(array_equal(isclose(a, b), array({true, true})).item()); + } + { + array a({1.0, -std::numeric_limits::infinity()}); + array b({1.0, -std::numeric_limits::infinity()}); + CHECK(array_equal(isclose(a, b), array({true, true})).item()); + } + { + array a({1.0, std::numeric_limits::infinity()}); + array b({1.0, -std::numeric_limits::infinity()}); + CHECK(array_equal(isclose(a, b), array({true, false})).item()); + } + { + array a({1.0, std::nan("1"), std::nan("1")}); + array b({1.0, std::nan("1"), 2.0}); + CHECK(array_equal(isclose(a, b), array({true, false, false})).item()); + } + { + array a({1.0, std::nan("1"), std::nan("1")}); + array b({1.0, std::nan("1"), 2.0}); + CHECK( + array_equal(isclose(a, b, 1e-5, 1e-8, true), array({true, true, false})) + .item()); + } +} + TEST_CASE("test reduction ops") { // Check shapes and throws correctly {