mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 15:04:40 +08:00
Add cpu Max nanpropagation. Fix a small fib in cpu max dispatch data types for int8/int16.
This commit is contained in:
@@ -325,7 +325,15 @@ struct MaxReduce {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <int N, typename T>
|
template <int N, typename T>
|
||||||
T operator()(simd::Simd<T, N> x) {
|
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||||
|
return simd::max(x);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int N, typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||||
|
if (simd::any(x != x)) {
|
||||||
|
return static_cast<T>(NAN);
|
||||||
|
}
|
||||||
return simd::max(x);
|
return simd::max(x);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
@@ -527,10 +535,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case int8:
|
case int8:
|
||||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_min_max<int8_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case int16:
|
case int16:
|
||||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_min_max<int16_t>(in, out, reduce_type_, axes_);
|
||||||
break;
|
break;
|
||||||
case int32:
|
case int32:
|
||||||
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
|
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
|
||||||
|
Reference in New Issue
Block a user