mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 21:21:16 +08:00
Align mlx::core::min op nan propagation with NumPy (#2346)
This commit is contained in:
parent
85873cb162
commit
8c7bc30ce4
@ -203,6 +203,11 @@ void time_reductions() {
|
|||||||
TIME(max_along_0);
|
TIME(max_along_0);
|
||||||
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
|
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
|
||||||
TIME(max_along_1);
|
TIME(max_along_1);
|
||||||
|
|
||||||
|
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
|
||||||
|
TIME(min_along_0);
|
||||||
|
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
|
||||||
|
TIME(min_along_1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_gather_scatter() {
|
void time_gather_scatter() {
|
||||||
|
@ -58,6 +58,13 @@ def time_max():
|
|||||||
time_fn(mx.max, a, 0)
|
time_fn(mx.max, a, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def time_min():
|
||||||
|
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||||
|
a[1, 1] = mx.nan
|
||||||
|
mx.eval(a)
|
||||||
|
time_fn(mx.min, a, 0)
|
||||||
|
|
||||||
|
|
||||||
def time_negative():
|
def time_negative():
|
||||||
a = mx.random.uniform(shape=(10000, 1000))
|
a = mx.random.uniform(shape=(10000, 1000))
|
||||||
mx.eval(a)
|
mx.eval(a)
|
||||||
@ -115,6 +122,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
time_add()
|
time_add()
|
||||||
time_matmul()
|
time_matmul()
|
||||||
|
time_min()
|
||||||
time_max()
|
time_max()
|
||||||
time_maximum()
|
time_maximum()
|
||||||
time_exp()
|
time_exp()
|
||||||
|
@ -350,7 +350,15 @@ struct MinReduce {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <int N, typename T>
|
template <int N, typename T>
|
||||||
T operator()(simd::Simd<T, N> x) {
|
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||||
|
return simd::min(x);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int N, typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||||
|
if (simd::any(x != x)) {
|
||||||
|
return static_cast<T>(NAN);
|
||||||
|
}
|
||||||
return simd::min(x);
|
return simd::min(x);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
@ -164,7 +164,15 @@ struct Min {
|
|||||||
DEFINE_SIMD_REDUCE()
|
DEFINE_SIMD_REDUCE()
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T simd_reduce_impl(T val) {
|
metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
|
||||||
|
return simd_min(val);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
|
||||||
|
if (simd_any(val != val)) {
|
||||||
|
return static_cast<T>(NAN);
|
||||||
|
}
|
||||||
return simd_min(val);
|
return simd_min(val);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -176,11 +184,38 @@ struct Min {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Operator
|
// Operator
|
||||||
U operator()(U a, U b) {
|
template <typename T>
|
||||||
|
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {
|
||||||
return a < b ? a : b;
|
return a < b ? a : b;
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
|
||||||
|
if (metal::isnan(a) || metal::isnan(b)) {
|
||||||
|
return static_cast<T>(NAN);
|
||||||
|
} else {
|
||||||
|
return a < b ? a : b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
complex64_t operator()(complex64_t a, complex64_t b) {
|
||||||
|
bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real);
|
||||||
|
bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag);
|
||||||
|
|
||||||
|
if (!real_is_nan && !imag_is_nan) {
|
||||||
|
return a < b ? a : b;
|
||||||
|
} else if (real_is_nan && !imag_is_nan) {
|
||||||
|
return complex64_t(
|
||||||
|
static_cast<float>(NAN), a.imag < b.imag ? a.imag : b.imag);
|
||||||
|
} else if (!real_is_nan && imag_is_nan) {
|
||||||
|
return complex64_t(
|
||||||
|
a.real < b.real ? a.real : b.real, static_cast<float>(NAN));
|
||||||
|
} else {
|
||||||
|
return complex64_t(static_cast<float>(NAN), static_cast<float>(NAN));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
template <typename U>
|
template <typename U>
|
||||||
struct Max {
|
struct Max {
|
||||||
DEFINE_SIMD_REDUCE()
|
DEFINE_SIMD_REDUCE()
|
||||||
|
@ -173,7 +173,7 @@ class TestReduce(mlx_tests.MLXTestCase):
|
|||||||
x[idx[0], idx[1]] = mx.nan
|
x[idx[0], idx[1]] = mx.nan
|
||||||
x_np = np.array(x)
|
x_np = np.array(x)
|
||||||
|
|
||||||
for op in ["max"]:
|
for op in ["max", "min"]:
|
||||||
for axis in [0, 1]:
|
for axis in [0, 1]:
|
||||||
out = getattr(mx, op)(x, axis=axis)
|
out = getattr(mx, op)(x, axis=axis)
|
||||||
ref = getattr(np, op)(x_np, axis=axis)
|
ref = getattr(np, op)(x_np, axis=axis)
|
||||||
@ -205,7 +205,7 @@ class TestReduce(mlx_tests.MLXTestCase):
|
|||||||
np_arrays,
|
np_arrays,
|
||||||
):
|
):
|
||||||
for axis in [0, 1]:
|
for axis in [0, 1]:
|
||||||
for op in ["max"]:
|
for op in ["max", "min"]:
|
||||||
out = getattr(mx, op)(mx_arr, axis=axis)
|
out = getattr(mx, op)(mx_arr, axis=axis)
|
||||||
ref = getattr(np, op)(np_arr, axis=axis)
|
ref = getattr(np, op)(np_arr, axis=axis)
|
||||||
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
||||||
|
Loading…
Reference in New Issue
Block a user