mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
@@ -234,7 +234,7 @@ void scan_dispatch(
|
||||
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
|
||||
auto init = (issubdtype(input.dtype(), floating))
|
||||
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::max();
|
||||
: std::numeric_limits<U>::min();
|
||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
|
||||
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
|
||||
|
@@ -309,6 +309,7 @@ template <
|
||||
}
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Share the prefix
|
||||
if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
|
||||
|
@@ -65,16 +65,16 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Compute the thread grid
|
||||
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
|
||||
int elements_per_simd = n_reads * 32;
|
||||
constexpr int simd_size = 32;
|
||||
int elements_per_simd = n_reads * simd_size;
|
||||
int thread_groups = in.size() / size;
|
||||
int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (size < n_reads * 1024) {
|
||||
thread_group_size = ((size + elements_per_simd - 1) / elements_per_simd) *
|
||||
elements_per_simd;
|
||||
} else if (size < n_reads * 2048) {
|
||||
if (size <= n_reads * 1024) {
|
||||
thread_group_size =
|
||||
((size / 2 + elements_per_simd - 1) / elements_per_simd) *
|
||||
elements_per_simd;
|
||||
((size + elements_per_simd - 1) / elements_per_simd) * simd_size;
|
||||
} else if (size <= n_reads * 2048) {
|
||||
thread_group_size =
|
||||
((size / 2 + elements_per_simd - 1) / elements_per_simd) * simd_size;
|
||||
}
|
||||
thread_group_size = std::min(
|
||||
thread_group_size,
|
||||
|
Reference in New Issue
Block a user