mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 09:18:12 +08:00
fix complex reduce + nan propagation in min and max (#2377)
This commit is contained in:
@@ -38,14 +38,13 @@ inline __host__ __device__ complex_t<T> operator%(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __host__ __device__ bool operator<(complex_t<T> a, complex_t<T> b) {
|
||||
return (a.real() * a.real() + a.imag() * a.imag()) <
|
||||
(b.real() * b.real() + b.imag() * b.imag());
|
||||
inline __host__ __device__ bool operator>(complex_t<T> a, complex_t<T> b) {
|
||||
return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __host__ __device__ bool operator>(complex_t<T> a, complex_t<T> b) {
|
||||
return b < a;
|
||||
inline __host__ __device__ bool operator<(complex_t<T> a, complex_t<T> b) {
|
||||
return operator>(b, a);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@@ -69,6 +69,18 @@ struct Prod {
|
||||
struct Min {
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T operator()(T a, T b) {
|
||||
if constexpr (is_complex_v<T>) {
|
||||
if (isnan(a.real()) || isnan(a.imag())) {
|
||||
return a;
|
||||
}
|
||||
if (isnan(b.real()) || isnan(b.imag())) {
|
||||
return b;
|
||||
}
|
||||
} else if constexpr (!cuda::std::is_integral_v<T>) {
|
||||
if (isnan(a) || isnan(b)) {
|
||||
return cuda::std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
}
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
@@ -81,6 +93,18 @@ struct Min {
|
||||
struct Max {
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T operator()(T a, T b) {
|
||||
if constexpr (is_complex_v<T>) {
|
||||
if (isnan(a.real()) || isnan(a.imag())) {
|
||||
return a;
|
||||
}
|
||||
if (isnan(b.real()) || isnan(b.imag())) {
|
||||
return b;
|
||||
}
|
||||
} else if constexpr (!cuda::std::is_integral_v<T>) {
|
||||
if (isnan(a) || isnan(b)) {
|
||||
return cuda::std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
}
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user