mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Split encoders in non-concurrent context with a max ops per encoder (#1085)
* split encoders * fix race
This commit is contained in:
@@ -77,7 +77,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
|
||||
MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
kname << "strided_scan_";
|
||||
if (reverse_) {
|
||||
@@ -119,7 +119,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
|
||||
MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
|
||||
MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
if (copies.size() > 0) {
|
||||
|
||||
Reference in New Issue
Block a user