mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Check isnan in maximum / minimum with CPU backend (#2652)
* Check isnan in maximum / minimum with CPU backend * Add tests * fix --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -217,14 +217,20 @@ Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
|
||||
// TODO add isnan
|
||||
return asd::max(a.value, b.value);
|
||||
auto out = Simd<T, N>(asd::max(a.value, b.value));
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
out = select(isnan(b), b, select(isnan(a), a, out));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {
|
||||
// TODO add isnan
|
||||
return asd::min(a.value, b.value);
|
||||
auto out = Simd<T, N>(asd::min(a.value, b.value));
|
||||
if constexpr (!std::is_integral_v<T>) {
|
||||
out = select(isnan(b), b, select(isnan(a), a, out));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
|
||||
@@ -4052,3 +4052,24 @@ TEST_CASE("test fp8 conversion") {
|
||||
auto expected = array({-448.0f, 448.0f});
|
||||
CHECK(array_equal(out, expected, true).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test max min with nan") {
|
||||
// Test maximum and minimum with NaN values
|
||||
auto x = array({0.0f, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
|
||||
auto y = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
|
||||
auto expected_max = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
|
||||
auto expected_min = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
|
||||
auto max_result = maximum(x, y);
|
||||
auto min_result = minimum(x, y);
|
||||
CHECK(array_equal(max_result, expected_max, true).item<bool>());
|
||||
CHECK(array_equal(min_result, expected_min, true).item<bool>());
|
||||
|
||||
// Test with all NaN values
|
||||
x = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN});
|
||||
y = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN});
|
||||
max_result = maximum(x, y);
|
||||
min_result = minimum(x, y);
|
||||
auto expected = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN});
|
||||
CHECK(array_equal(max_result, expected, true).item<bool>());
|
||||
CHECK(array_equal(min_result, expected, true).item<bool>());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user