Align mlx::core::min op nan propagation with NumPy (#2346)

This commit is contained in:
jhavukainen
2025-07-10 06:20:43 -07:00
committed by GitHub
parent 85873cb162
commit 8c7bc30ce4
5 changed files with 62 additions and 6 deletions

View File

@@ -203,6 +203,11 @@ void time_reductions() {
TIME(max_along_0);
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
TIME(max_along_1);
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
TIME(min_along_0);
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
TIME(min_along_1);
}
void time_gather_scatter() {

View File

@@ -58,6 +58,13 @@ def time_max():
time_fn(mx.max, a, 0)
def time_min():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.min, a, 0)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)
@@ -115,6 +122,7 @@ if __name__ == "__main__":
time_add()
time_matmul()
time_min()
time_max()
time_maximum()
time_exp()