[Fix] mx.allclose bug with infinite values (#539)

* Added isclose op and fixed comparison with inf values

* Added 'equal_nan' to match numpy

* format

* Add test

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Addressed CR comments

* Update python/src/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* nits

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Rifur13
2024-01-25 23:47:06 -05:00
committed by GitHub
parent 87b7fa9ba2
commit 2463496471
5 changed files with 143 additions and 11 deletions

View File

@@ -1127,20 +1127,17 @@ array isnan(const array& a, StreamOrDevice s /* = {} */) {
}
array isinf(const array& a, StreamOrDevice s /* = {} */) {
return logical_or(isposinf(a, s), isneginf(a, s), s);
}
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
return full(a.shape(), false, bool_, s);
}
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
}
array isposinf(const array& a, StreamOrDevice s) {
if (is_integral(a.dtype())) {
return full(a.shape(), false, bool_, s);
}
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
}
array isneginf(const array& a, StreamOrDevice s) {
array isneginf(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
return full(a.shape(), false, bool_, s);
}
@@ -1162,11 +1159,43 @@ array allclose(
const array& b,
double rtol /* = 1e-5 */,
double atol /* = 1e-8 */,
bool equal_nan /* = false */,
StreamOrDevice s /* = {}*/) {
return all(isclose(a, b, rtol, atol, equal_nan, s), s);
}
array isclose(
const array& a,
const array& b,
double rtol /* = 1e-5 */,
double atol /* = 1e-8 */,
bool equal_nan /* = false */,
StreamOrDevice s /* = {}*/) {
// |a - b| <= atol + rtol * |b|
auto rhs = add(array(atol), multiply(array(rtol), abs(b, s), s), s);
auto lhs = abs(subtract(a, b, s), s);
return all(less_equal(lhs, rhs, s), s);
auto out = less_equal(lhs, rhs, s);
// Correct the result for infinite values.
auto any_inf = logical_or(isinf(a, s), isinf(b, s), s);
auto both_inf = logical_or(
logical_and(isposinf(a, s), isposinf(b, s), s),
logical_and(isneginf(a, s), isneginf(b, s), s),
s);
// Convert all elements where either value is infinite to False.
out = logical_and(out, logical_not(any_inf, s), s);
// Convert all the elements where both values are infinite and of the same
// sign to True.
out = logical_or(out, both_inf, s);
if (equal_nan) {
auto both_nan = logical_and(isnan(a, s), isnan(b, s), s);
out = logical_or(out, both_nan, s);
}
return out;
}
array all(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {

View File

@@ -404,6 +404,17 @@ array allclose(
const array& b,
double rtol = 1e-5,
double atol = 1e-8,
bool equal_nan = false,
StreamOrDevice s = {});
/** Returns a boolean array where two arrays are element-wise equal within the
* specified tolerance. */
array isclose(
const array& a,
const array& b,
double rtol = 1e-5,
double atol = 1e-8,
bool equal_nan = false,
StreamOrDevice s = {});
/**