mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-14 12:31:13 +08:00
Align mlx::core::min op nan propagation with NumPy (#2346)
This commit is contained in:
parent
85873cb162
commit
8c7bc30ce4
@ -203,6 +203,11 @@ void time_reductions() {
|
||||
TIME(max_along_0);
|
||||
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
|
||||
TIME(max_along_1);
|
||||
|
||||
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
|
||||
TIME(min_along_0);
|
||||
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
|
||||
TIME(min_along_1);
|
||||
}
|
||||
|
||||
void time_gather_scatter() {
|
||||
|
@ -58,6 +58,13 @@ def time_max():
|
||||
time_fn(mx.max, a, 0)
|
||||
|
||||
|
||||
def time_min():
|
||||
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||
a[1, 1] = mx.nan
|
||||
mx.eval(a)
|
||||
time_fn(mx.min, a, 0)
|
||||
|
||||
|
||||
def time_negative():
|
||||
a = mx.random.uniform(shape=(10000, 1000))
|
||||
mx.eval(a)
|
||||
@ -115,6 +122,7 @@ if __name__ == "__main__":
|
||||
|
||||
time_add()
|
||||
time_matmul()
|
||||
time_min()
|
||||
time_max()
|
||||
time_maximum()
|
||||
time_exp()
|
||||
|
@ -350,7 +350,15 @@ struct MinReduce {
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||
return simd::min(x);
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||
if (simd::any(x != x)) {
|
||||
return static_cast<T>(NAN);
|
||||
}
|
||||
return simd::min(x);
|
||||
};
|
||||
};
|
||||
|
@ -164,7 +164,15 @@ struct Min {
|
||||
DEFINE_SIMD_REDUCE()
|
||||
|
||||
template <typename T>
|
||||
T simd_reduce_impl(T val) {
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
|
||||
return simd_min(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
|
||||
if (simd_any(val != val)) {
|
||||
return static_cast<T>(NAN);
|
||||
}
|
||||
return simd_min(val);
|
||||
}
|
||||
|
||||
@ -176,11 +184,38 @@ struct Min {
|
||||
}
|
||||
|
||||
// Operator
|
||||
U operator()(U a, U b) {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
|
||||
if (metal::isnan(a) || metal::isnan(b)) {
|
||||
return static_cast<T>(NAN);
|
||||
} else {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t a, complex64_t b) {
|
||||
bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real);
|
||||
bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag);
|
||||
|
||||
if (!real_is_nan && !imag_is_nan) {
|
||||
return a < b ? a : b;
|
||||
} else if (real_is_nan && !imag_is_nan) {
|
||||
return complex64_t(
|
||||
static_cast<float>(NAN), a.imag < b.imag ? a.imag : b.imag);
|
||||
} else if (!real_is_nan && imag_is_nan) {
|
||||
return complex64_t(
|
||||
a.real < b.real ? a.real : b.real, static_cast<float>(NAN));
|
||||
} else {
|
||||
return complex64_t(static_cast<float>(NAN), static_cast<float>(NAN));
|
||||
}
|
||||
};
|
||||
};
|
||||
template <typename U>
|
||||
struct Max {
|
||||
DEFINE_SIMD_REDUCE()
|
||||
|
@ -173,7 +173,7 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
x[idx[0], idx[1]] = mx.nan
|
||||
x_np = np.array(x)
|
||||
|
||||
for op in ["max"]:
|
||||
for op in ["max", "min"]:
|
||||
for axis in [0, 1]:
|
||||
out = getattr(mx, op)(x, axis=axis)
|
||||
ref = getattr(np, op)(x_np, axis=axis)
|
||||
@ -205,7 +205,7 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
np_arrays,
|
||||
):
|
||||
for axis in [0, 1]:
|
||||
for op in ["max"]:
|
||||
for op in ["max", "min"]:
|
||||
out = getattr(mx, op)(mx_arr, axis=axis)
|
||||
ref = getattr(np, op)(np_arr, axis=axis)
|
||||
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
||||
|
Loading…
Reference in New Issue
Block a user