mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Split encoders in non-concurrent context with a max ops per encoder (#1085)
* split encoders * fix race
This commit is contained in:
@@ -78,7 +78,7 @@ void single_block_sort(
|
||||
MTL::Size group_dims = MTL::Size(bn, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
template <bool ARGSORT>
|
||||
@@ -155,7 +155,7 @@ void multi_block_sort(
|
||||
MTL::Size group_dims = MTL::Size(bn, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Do merges
|
||||
@@ -190,7 +190,7 @@ void multi_block_sort(
|
||||
MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Do merge
|
||||
@@ -214,7 +214,7 @@ void multi_block_sort(
|
||||
MTL::Size group_dims = MTL::Size(bn, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user