mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
2 Commits
9794ec6b8e
...
3e885f583a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e885f583a | ||
|
|
c7af3016eb |
@@ -193,8 +193,8 @@ void time_reductions() {
|
||||
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
||||
TIME(argmin_along_1);
|
||||
|
||||
auto indices = mlx::core::array({1});
|
||||
auto updates = mlx::core::reshape(mlx::core::array({NAN}), {1, 1, 1});
|
||||
auto indices = mx::array({1});
|
||||
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
|
||||
std::vector<int> axes{0};
|
||||
auto b = scatter(a, {indices}, updates, axes);
|
||||
mx::eval(b);
|
||||
|
||||
@@ -186,7 +186,12 @@ struct Max {
|
||||
DEFINE_SIMD_REDUCE()
|
||||
|
||||
template <typename T>
|
||||
T simd_reduce_impl(T val) {
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
|
||||
return simd_max(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
|
||||
if (simd_any(val != val)) {
|
||||
return static_cast<T>(NAN);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user