mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use load_vector in arg_reduce (#2439)
This commit is contained in:
@@ -131,20 +131,6 @@ inline __device__ void store_vector(
|
||||
}
|
||||
}
|
||||
|
||||
// Helper for accessing strided data.
|
||||
template <typename T>
|
||||
struct StridedIterator {
|
||||
T it;
|
||||
int64_t stride;
|
||||
|
||||
__host__ __device__ StridedIterator(T it, int64_t stride)
|
||||
: it(it), stride(stride) {}
|
||||
|
||||
__host__ __device__ auto operator[](int i) const {
|
||||
return it[i * stride];
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Type limits utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user