Check isnan in maximum / minimum with CPU backend (#2652)

* Check isnan in maximum / minimum with CPU backend

* Add tests

* fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
AN Long
2025-11-04 01:51:14 +09:00
committed by GitHub
parent 50514a6146
commit 1ff2b713b6
2 changed files with 31 additions and 4 deletions

View File

@@ -217,14 +217,20 @@ Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
template <typename T, int N>
Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
// TODO add isnan
return asd::max(a.value, b.value);
auto out = Simd<T, N>(asd::max(a.value, b.value));
if constexpr (!std::is_integral_v<T>) {
out = select(isnan(b), b, select(isnan(a), a, out));
}
return out;
}
template <typename T, int N>
Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {
// TODO add isnan
return asd::min(a.value, b.value);
auto out = Simd<T, N>(asd::min(a.value, b.value));
if constexpr (!std::is_integral_v<T>) {
out = select(isnan(b), b, select(isnan(a), a, out));
}
return out;
}
template <typename T, int N>

View File

@@ -4052,3 +4052,24 @@ TEST_CASE("test fp8 conversion") {
auto expected = array({-448.0f, 448.0f});
CHECK(array_equal(out, expected, true).item<bool>());
}
TEST_CASE("test max min with nan") {
// Test maximum and minimum with NaN values
auto x = array({0.0f, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
auto y = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
auto expected_max = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
auto expected_min = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
auto max_result = maximum(x, y);
auto min_result = minimum(x, y);
CHECK(array_equal(max_result, expected_max, true).item<bool>());
CHECK(array_equal(min_result, expected_min, true).item<bool>());
// Test with all NaN values
x = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN});
y = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN});
max_result = maximum(x, y);
min_result = minimum(x, y);
auto expected = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN});
CHECK(array_equal(max_result, expected, true).item<bool>());
CHECK(array_equal(min_result, expected, true).item<bool>());
}