* fix scan

* improve grid size

* fix cpu cummax
This commit is contained in:
Awni Hannun
2024-06-05 14:21:58 -07:00
committed by GitHub
parent 0fe6895893
commit 496315fe1d
4 changed files with 23 additions and 8 deletions

View File

@@ -309,6 +309,7 @@ template <
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Share the prefix
if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {

View File

@@ -65,16 +65,16 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
// Compute the thread grid
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
int elements_per_simd = n_reads * 32;
constexpr int simd_size = 32;
int elements_per_simd = n_reads * simd_size;
int thread_groups = in.size() / size;
int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (size < n_reads * 1024) {
thread_group_size = ((size + elements_per_simd - 1) / elements_per_simd) *
elements_per_simd;
} else if (size < n_reads * 2048) {
if (size <= n_reads * 1024) {
thread_group_size =
((size / 2 + elements_per_simd - 1) / elements_per_simd) *
elements_per_simd;
((size + elements_per_simd - 1) / elements_per_simd) * simd_size;
} else if (size <= n_reads * 2048) {
thread_group_size =
((size / 2 + elements_per_simd - 1) / elements_per_simd) * simd_size;
}
thread_group_size = std::min(
thread_group_size,