Split encoders in non-concurrent context with a max ops per encoder (#1085)

* split encoders

* fix race
This commit is contained in:
Awni Hannun
2024-05-09 16:21:02 -07:00
committed by GitHub
parent b21242faf1
commit 06375e6605
18 changed files with 150 additions and 138 deletions

View File

@@ -356,7 +356,7 @@ void steel_matmul_conv_groups(
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
@@ -468,7 +468,7 @@ void steel_matmul(
compute_encoder.set_output_array(C_split, 2);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Do accum kernel
{
@@ -493,7 +493,7 @@ void steel_matmul(
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
@@ -581,7 +581,7 @@ void steel_matmul(
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
@@ -748,7 +748,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(
batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -968,7 +968,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int bias_stride = c.strides()[c.ndim() - 1];
compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -1038,7 +1038,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(C_split, 2);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Do accum kernel
{
@@ -1063,7 +1063,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
@@ -1160,7 +1160,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -1346,7 +1346,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(out_mask, 10);
set_vector_bytes(compute_encoder, mask_strides, 13);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
@@ -1566,7 +1566,7 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -1656,7 +1656,7 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
set_vector_bytes(compute_encoder, batch_strides_B, 15);
set_vector_bytes(compute_encoder, operand_batch_ndim, 16);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(