mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Propagate nans in binary ops (#579)
* propagate nans in binary ops * handle empty matmul * cpu minimum/maximum propagate nan * benchmark maximum * add min as well * throw on negative indices with full * verbose on linux * fix matmul for zero K
This commit is contained in:
@@ -58,6 +58,9 @@ struct LessEqual {
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
if (metal::isnan(x) || metal::isnan(y)) {
|
||||
return metal::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
constexpr T inf = metal::numeric_limits<T>::infinity();
|
||||
T maxval = metal::max(x, y);
|
||||
T minval = metal::min(x, y);
|
||||
@@ -67,20 +70,48 @@ struct LogAddExp {
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
template <typename T> T operator()(T x, T y) { return metal::max(x, y); }
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return metal::max(x, y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (metal::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x >= y ? x : y;
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Minimum {
|
||||
template <typename T> T operator()(T x, T y) { return metal::min(x, y); }
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return metal::min(x, y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (metal::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x <= y ? x : y;
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -389,4 +420,4 @@ instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
|
||||
|
||||
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
|
||||
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
|
||||
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
|
||||
|
Reference in New Issue
Block a user