Split encoders in non-concurrent context with a max ops per encoder (#1085)

* split encoders

* fix race
Awni Hannun
2024-05-09 16:21:02 -07:00
committed by GitHub
parent b21242faf1
commit 06375e6605
18 changed files with 150 additions and 138 deletions
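The call sites in the diff below change from compute_encoder->dispatchThreads(...) to compute_encoder.dispatchThreads(...): the encoder variable is no longer a raw MTL::ComputeCommandEncoder pointer but a wrapper object that can intercept each dispatch. The following is a minimal sketch of how such a wrapper could split encoders after a maximum number of ops in a non-concurrent context; the names (CommandEncoder, MAX_OPS_PER_ENCODER, maybe_split) are assumptions for illustration, not the actual MLX implementation.

```cpp
// Sketch only: a dispatch-counting wrapper around MTL::ComputeCommandEncoder.
// MAX_OPS_PER_ENCODER and maybe_split() are hypothetical names.
#include <Metal/Metal.hpp>

class CommandEncoder {
 public:
  explicit CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf_(cbuf) {
    enc_ = cbuf_->computeCommandEncoder();
  }

  // Call sites use `.` instead of `->`, so every dispatch goes through the
  // wrapper and can be counted before being forwarded to Metal.
  void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims) {
    maybe_split();
    enc_->dispatchThreads(grid_dims, group_dims);
  }

  ~CommandEncoder() {
    if (enc_ != nullptr) {
      enc_->endEncoding();
    }
  }

 private:
  // Hypothetical limit: once an encoder has recorded this many ops in a
  // non-concurrent context, end it and start a fresh one.
  static constexpr int MAX_OPS_PER_ENCODER = 1000;

  void maybe_split() {
    if (++num_ops_ > MAX_OPS_PER_ENCODER) {
      enc_->endEncoding();
      enc_ = cbuf_->computeCommandEncoder();
      num_ops_ = 0;
    }
  }

  MTL::CommandBuffer* cbuf_;
  MTL::ComputeCommandEncoder* enc_ = nullptr;
  int num_ops_ = 0;
};
```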


@@ -77,7 +77,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-compute_encoder->dispatchThreads(grid_dims, group_dims);
+compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
kname << "strided_scan_";
if (reverse_) {
@@ -119,7 +119,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
-compute_encoder->dispatchThreads(grid_dims, group_dims);
+compute_encoder.dispatchThreads(grid_dims, group_dims);
}
if (copies.size() > 0) {