mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
[Fix] mx.allclose bug with infinite values (#539)
* Added isclose op and fixed comparison with inf values * Added 'equal_nan' to match numpy * format * Add test * Update python/src/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Addressed CR comments * Update python/src/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * nits --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
47
mlx/ops.cpp
47
mlx/ops.cpp
@@ -1127,20 +1127,17 @@ array isnan(const array& a, StreamOrDevice s /* = {} */) {
|
||||
}
|
||||
|
||||
array isinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return logical_or(isposinf(a, s), isneginf(a, s), s);
|
||||
}
|
||||
|
||||
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (is_integral(a.dtype())) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||
}
|
||||
|
||||
array isposinf(const array& a, StreamOrDevice s) {
|
||||
if (is_integral(a.dtype())) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
|
||||
}
|
||||
|
||||
array isneginf(const array& a, StreamOrDevice s) {
|
||||
array isneginf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (is_integral(a.dtype())) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
@@ -1162,11 +1159,43 @@ array allclose(
|
||||
const array& b,
|
||||
double rtol /* = 1e-5 */,
|
||||
double atol /* = 1e-8 */,
|
||||
bool equal_nan /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
return all(isclose(a, b, rtol, atol, equal_nan, s), s);
|
||||
}
|
||||
|
||||
array isclose(
|
||||
const array& a,
|
||||
const array& b,
|
||||
double rtol /* = 1e-5 */,
|
||||
double atol /* = 1e-8 */,
|
||||
bool equal_nan /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
// |a - b| <= atol + rtol * |b|
|
||||
auto rhs = add(array(atol), multiply(array(rtol), abs(b, s), s), s);
|
||||
auto lhs = abs(subtract(a, b, s), s);
|
||||
return all(less_equal(lhs, rhs, s), s);
|
||||
auto out = less_equal(lhs, rhs, s);
|
||||
|
||||
// Correct the result for infinite values.
|
||||
auto any_inf = logical_or(isinf(a, s), isinf(b, s), s);
|
||||
auto both_inf = logical_or(
|
||||
logical_and(isposinf(a, s), isposinf(b, s), s),
|
||||
logical_and(isneginf(a, s), isneginf(b, s), s),
|
||||
s);
|
||||
|
||||
// Convert all elements where either value is infinite to False.
|
||||
out = logical_and(out, logical_not(any_inf, s), s);
|
||||
|
||||
// Convert all the elements where both values are infinite and of the same
|
||||
// sign to True.
|
||||
out = logical_or(out, both_inf, s);
|
||||
|
||||
if (equal_nan) {
|
||||
auto both_nan = logical_and(isnan(a, s), isnan(b, s), s);
|
||||
out = logical_or(out, both_nan, s);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
array all(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
|
||||
|
11
mlx/ops.h
11
mlx/ops.h
@@ -404,6 +404,17 @@ array allclose(
|
||||
const array& b,
|
||||
double rtol = 1e-5,
|
||||
double atol = 1e-8,
|
||||
bool equal_nan = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Returns a boolean array where two arrays are element-wise equal within the
|
||||
* specified tolerance. */
|
||||
array isclose(
|
||||
const array& a,
|
||||
const array& b,
|
||||
double rtol = 1e-5,
|
||||
double atol = 1e-8,
|
||||
bool equal_nan = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/**
|
||||
|
Reference in New Issue
Block a user