Split encoders in non-concurrent context with a max ops per encoder (#1085)

* split encoders

* fix race
This commit is contained in:
Awni Hannun
2024-05-09 16:21:02 -07:00
committed by GitHub
parent b21242faf1
commit 06375e6605
18 changed files with 150 additions and 138 deletions

View File

@@ -78,7 +78,7 @@ void single_block_sort(
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
template <bool ARGSORT>
@@ -155,7 +155,7 @@ void multi_block_sort(
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do merges
@@ -190,7 +190,7 @@ void multi_block_sort(
MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do merge
@@ -214,7 +214,7 @@ void multi_block_sort(
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}