mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
2 Commits
9794ec6b8e
...
3e885f583a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e885f583a | ||
|
|
c7af3016eb |
@@ -193,8 +193,8 @@ void time_reductions() {
|
|||||||
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
||||||
TIME(argmin_along_1);
|
TIME(argmin_along_1);
|
||||||
|
|
||||||
auto indices = mlx::core::array({1});
|
auto indices = mx::array({1});
|
||||||
auto updates = mlx::core::reshape(mlx::core::array({NAN}), {1, 1, 1});
|
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
|
||||||
std::vector<int> axes{0};
|
std::vector<int> axes{0};
|
||||||
auto b = scatter(a, {indices}, updates, axes);
|
auto b = scatter(a, {indices}, updates, axes);
|
||||||
mx::eval(b);
|
mx::eval(b);
|
||||||
|
|||||||
@@ -186,7 +186,12 @@ struct Max {
|
|||||||
DEFINE_SIMD_REDUCE()
|
DEFINE_SIMD_REDUCE()
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T simd_reduce_impl(T val) {
|
metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
|
||||||
|
return simd_max(val);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
|
||||||
if (simd_any(val != val)) {
|
if (simd_any(val != val)) {
|
||||||
return static_cast<T>(NAN);
|
return static_cast<T>(NAN);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user