mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
[Fix] mx.allclose bug with infinite values (#539)
* Added isclose op and fixed comparison with inf values * Added 'equal_nan' to match numpy * format * Add test * Update python/src/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Addressed CR comments * Update python/src/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * nits --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
87b7fa9ba2
commit
2463496471
47
mlx/ops.cpp
47
mlx/ops.cpp
@ -1127,20 +1127,17 @@ array isnan(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
array isinf(const array& a, StreamOrDevice s /* = {} */) {
|
array isinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
return logical_or(isposinf(a, s), isneginf(a, s), s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
if (is_integral(a.dtype())) {
|
if (is_integral(a.dtype())) {
|
||||||
return full(a.shape(), false, bool_, s);
|
return full(a.shape(), false, bool_, s);
|
||||||
}
|
}
|
||||||
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array isposinf(const array& a, StreamOrDevice s) {
|
array isneginf(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
if (is_integral(a.dtype())) {
|
|
||||||
return full(a.shape(), false, bool_, s);
|
|
||||||
}
|
|
||||||
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
|
||||||
}
|
|
||||||
|
|
||||||
array isneginf(const array& a, StreamOrDevice s) {
|
|
||||||
if (is_integral(a.dtype())) {
|
if (is_integral(a.dtype())) {
|
||||||
return full(a.shape(), false, bool_, s);
|
return full(a.shape(), false, bool_, s);
|
||||||
}
|
}
|
||||||
@ -1162,11 +1159,43 @@ array allclose(
|
|||||||
const array& b,
|
const array& b,
|
||||||
double rtol /* = 1e-5 */,
|
double rtol /* = 1e-5 */,
|
||||||
double atol /* = 1e-8 */,
|
double atol /* = 1e-8 */,
|
||||||
|
bool equal_nan /* = false */,
|
||||||
|
StreamOrDevice s /* = {}*/) {
|
||||||
|
return all(isclose(a, b, rtol, atol, equal_nan, s), s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array isclose(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
double rtol /* = 1e-5 */,
|
||||||
|
double atol /* = 1e-8 */,
|
||||||
|
bool equal_nan /* = false */,
|
||||||
StreamOrDevice s /* = {}*/) {
|
StreamOrDevice s /* = {}*/) {
|
||||||
// |a - b| <= atol + rtol * |b|
|
// |a - b| <= atol + rtol * |b|
|
||||||
auto rhs = add(array(atol), multiply(array(rtol), abs(b, s), s), s);
|
auto rhs = add(array(atol), multiply(array(rtol), abs(b, s), s), s);
|
||||||
auto lhs = abs(subtract(a, b, s), s);
|
auto lhs = abs(subtract(a, b, s), s);
|
||||||
return all(less_equal(lhs, rhs, s), s);
|
auto out = less_equal(lhs, rhs, s);
|
||||||
|
|
||||||
|
// Correct the result for infinite values.
|
||||||
|
auto any_inf = logical_or(isinf(a, s), isinf(b, s), s);
|
||||||
|
auto both_inf = logical_or(
|
||||||
|
logical_and(isposinf(a, s), isposinf(b, s), s),
|
||||||
|
logical_and(isneginf(a, s), isneginf(b, s), s),
|
||||||
|
s);
|
||||||
|
|
||||||
|
// Convert all elements where either value is infinite to False.
|
||||||
|
out = logical_and(out, logical_not(any_inf, s), s);
|
||||||
|
|
||||||
|
// Convert all the elements where both values are infinite and of the same
|
||||||
|
// sign to True.
|
||||||
|
out = logical_or(out, both_inf, s);
|
||||||
|
|
||||||
|
if (equal_nan) {
|
||||||
|
auto both_nan = logical_and(isnan(a, s), isnan(b, s), s);
|
||||||
|
out = logical_or(out, both_nan, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
array all(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
|
array all(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
|
||||||
|
11
mlx/ops.h
11
mlx/ops.h
@ -404,6 +404,17 @@ array allclose(
|
|||||||
const array& b,
|
const array& b,
|
||||||
double rtol = 1e-5,
|
double rtol = 1e-5,
|
||||||
double atol = 1e-8,
|
double atol = 1e-8,
|
||||||
|
bool equal_nan = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Returns a boolean array where two arrays are element-wise equal within the
|
||||||
|
* specified tolerance. */
|
||||||
|
array isclose(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
double rtol = 1e-5,
|
||||||
|
double atol = 1e-8,
|
||||||
|
bool equal_nan = false,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -566,7 +566,7 @@ void init_ops(py::module_& m) {
|
|||||||
Args:
|
Args:
|
||||||
a (array): Input array or scalar.
|
a (array): Input array or scalar.
|
||||||
b (array): Input array or scalar.
|
b (array): Input array or scalar.
|
||||||
equal_nan (bool): If ``True``, NaNs are treated as equal.
|
equal_nan (bool): If ``True``, NaNs are considered equal.
|
||||||
Defaults to ``False``.
|
Defaults to ``False``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -1648,12 +1648,15 @@ void init_ops(py::module_& m) {
|
|||||||
"rtol"_a = 1e-5,
|
"rtol"_a = 1e-5,
|
||||||
"atol"_a = 1e-8,
|
"atol"_a = 1e-8,
|
||||||
py::kw_only(),
|
py::kw_only(),
|
||||||
|
"equal_nan"_a = false,
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, stream: Union[None, Stream, Device] = None) -> array
|
allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
Approximate comparison of two arrays.
|
Approximate comparison of two arrays.
|
||||||
|
|
||||||
|
Infinite values are considered equal if they have the same sign, NaN values are not equal unless ``equal_nan`` is ``True``.
|
||||||
|
|
||||||
The arrays are considered equal if:
|
The arrays are considered equal if:
|
||||||
|
|
||||||
.. code-block::
|
.. code-block::
|
||||||
@ -1668,6 +1671,47 @@ void init_ops(py::module_& m) {
|
|||||||
b (array): Input array.
|
b (array): Input array.
|
||||||
rtol (float): Relative tolerance.
|
rtol (float): Relative tolerance.
|
||||||
atol (float): Absolute tolerance.
|
atol (float): Absolute tolerance.
|
||||||
|
equal_nan (bool): If ``True``, NaNs are considered equal.
|
||||||
|
Defaults to ``False``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The boolean output scalar indicating if the arrays are close.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"isclose",
|
||||||
|
&isclose,
|
||||||
|
"a"_a,
|
||||||
|
"b"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
"rtol"_a = 1e-5,
|
||||||
|
"atol"_a = 1e-8,
|
||||||
|
py::kw_only(),
|
||||||
|
"equal_nan"_a = false,
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
isclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Returns a boolean array where two arrays are element-wise equal within a tolerance.
|
||||||
|
|
||||||
|
Infinite values are considered equal if they have the same sign, NaN values are
|
||||||
|
not equal unless ``equal_nan`` is ``True``.
|
||||||
|
|
||||||
|
Two values are considered equal if:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
abs(a - b) <= (atol + rtol * abs(b))
|
||||||
|
|
||||||
|
Note unlike :func:`array_equal`, this function supports numpy-style
|
||||||
|
broadcasting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array.
|
||||||
|
b (array): Input array.
|
||||||
|
rtol (float): Relative tolerance.
|
||||||
|
atol (float): Absolute tolerance.
|
||||||
|
equal_nan (bool): If ``True``, NaNs are considered equal.
|
||||||
|
Defaults to ``False``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: The boolean output scalar indicating if the arrays are close.
|
array: The boolean output scalar indicating if the arrays are close.
|
||||||
|
@ -855,6 +855,21 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertFalse(mx.allclose(a, b, 0.01).item())
|
self.assertFalse(mx.allclose(a, b, 0.01).item())
|
||||||
self.assertTrue(mx.allclose(a, b, 0.01, 0.1).item())
|
self.assertTrue(mx.allclose(a, b, 0.01, 0.1).item())
|
||||||
|
|
||||||
|
c = mx.array(float("inf"))
|
||||||
|
self.assertTrue(mx.allclose(c, c).item())
|
||||||
|
|
||||||
|
def test_isclose(self):
|
||||||
|
a = mx.array([float("inf"), float("inf"), float("-inf")])
|
||||||
|
b = mx.array([float("inf"), float("-inf"), float("-inf")])
|
||||||
|
|
||||||
|
self.assertListEqual(mx.isclose(a, b).tolist(), [True, False, True])
|
||||||
|
|
||||||
|
a = mx.array([np.nan])
|
||||||
|
self.assertListEqual(mx.isclose(a, a).tolist(), [False])
|
||||||
|
|
||||||
|
a = mx.array([np.nan])
|
||||||
|
self.assertListEqual(mx.isclose(a, a, equal_nan=True).tolist(), [True])
|
||||||
|
|
||||||
def test_all(self):
|
def test_all(self):
|
||||||
a = mx.array([[True, False], [True, True]])
|
a = mx.array([[True, False], [True, True]])
|
||||||
|
|
||||||
|
@ -514,6 +514,9 @@ TEST_CASE("test is inf") {
|
|||||||
array y(inf);
|
array y(inf);
|
||||||
CHECK(isinf(y).item<bool>());
|
CHECK(isinf(y).item<bool>());
|
||||||
|
|
||||||
|
auto neginf = -std::numeric_limits<float>::infinity();
|
||||||
|
CHECK(isinf(array(neginf)).item<bool>());
|
||||||
|
|
||||||
array z = identity(7);
|
array z = identity(7);
|
||||||
CHECK_FALSE(any(isinf(z)).item<bool>());
|
CHECK_FALSE(any(isinf(z)).item<bool>());
|
||||||
|
|
||||||
@ -545,6 +548,36 @@ TEST_CASE("test all close") {
|
|||||||
CHECK(allclose(x, y, 0.01, 0.1).item<bool>());
|
CHECK(allclose(x, y, 0.01, 0.1).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test is close") {
|
||||||
|
{
|
||||||
|
array a({1.0, std::numeric_limits<float>::infinity()});
|
||||||
|
array b({1.0, std::numeric_limits<float>::infinity()});
|
||||||
|
CHECK(array_equal(isclose(a, b), array({true, true})).item<bool>());
|
||||||
|
}
|
||||||
|
{
|
||||||
|
array a({1.0, -std::numeric_limits<float>::infinity()});
|
||||||
|
array b({1.0, -std::numeric_limits<float>::infinity()});
|
||||||
|
CHECK(array_equal(isclose(a, b), array({true, true})).item<bool>());
|
||||||
|
}
|
||||||
|
{
|
||||||
|
array a({1.0, std::numeric_limits<float>::infinity()});
|
||||||
|
array b({1.0, -std::numeric_limits<float>::infinity()});
|
||||||
|
CHECK(array_equal(isclose(a, b), array({true, false})).item<bool>());
|
||||||
|
}
|
||||||
|
{
|
||||||
|
array a({1.0, std::nan("1"), std::nan("1")});
|
||||||
|
array b({1.0, std::nan("1"), 2.0});
|
||||||
|
CHECK(array_equal(isclose(a, b), array({true, false, false})).item<bool>());
|
||||||
|
}
|
||||||
|
{
|
||||||
|
array a({1.0, std::nan("1"), std::nan("1")});
|
||||||
|
array b({1.0, std::nan("1"), 2.0});
|
||||||
|
CHECK(
|
||||||
|
array_equal(isclose(a, b, 1e-5, 1e-8, true), array({true, true, false}))
|
||||||
|
.item<bool>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_CASE("test reduction ops") {
|
TEST_CASE("test reduction ops") {
|
||||||
// Check shapes and throws correctly
|
// Check shapes and throws correctly
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user