Split encoders in non-concurrent context with a max ops per encoder (#1085)

* split encoders

* fix race
Author: Awni Hannun
Date: 2024-05-09 16:21:02 -07:00
Committed by: GitHub
Parent: b21242faf1
Commit: 06375e6605

18 changed files with 150 additions and 138 deletions
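The excerpt below covers one of the 18 changed files. The visible change is that dispatch calls in sdpa_metal go through compute_encoder.dispatchThreadgroups(...) instead of the raw compute_encoder->dispatchThreadgroups(...), i.e. dispatches are now routed through an encoder wrapper. Per the commit title, the point is that in a non-concurrent context the wrapper can count encoded ops and end/recreate the underlying MTL::ComputeCommandEncoder once a maximum is reached, so no single encoder accumulates an unbounded amount of work. What follows is a minimal sketch of that pattern in metal-cpp, not MLX's actual CommandEncoder; the class name SplittingEncoder and the cap kMaxOpsPerEncoder are made up for illustration, and the sketch assumes every kernel re-binds its pipeline state and buffers before dispatching.

// Minimal sketch (not the actual MLX CommandEncoder): a wrapper that counts
// dispatches and starts a fresh MTL::ComputeCommandEncoder once a cap is hit.
// Retain/release and autorelease-pool handling are omitted for brevity.
#include <Metal/Metal.hpp>

class SplittingEncoder {
 public:
  explicit SplittingEncoder(MTL::CommandBuffer* cbuf)
      : cbuf_(cbuf), enc_(cbuf->computeCommandEncoder()) {}

  ~SplittingEncoder() {
    enc_->endEncoding();
  }

  // Un-counted state setup (setComputePipelineState, setBuffer, setBytes,
  // setThreadgroupMemoryLength, ...) still goes through the raw encoder.
  MTL::ComputeCommandEncoder* operator->() {
    return enc_;
  }

  // Counted dispatches go through the wrapper itself.
  void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims) {
    enc_->dispatchThreadgroups(grid_dims, group_dims);
    maybe_split();
  }

  void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims) {
    enc_->dispatchThreads(grid_dims, group_dims);
    maybe_split();
  }

 private:
  static constexpr int kMaxOpsPerEncoder = 100;  // hypothetical cap

  // Split only at op boundaries: because each kernel re-binds its pipeline
  // state and buffers before dispatching, ending the encoder here is safe.
  void maybe_split() {
    if (++num_ops_ >= kMaxOpsPerEncoder) {
      enc_->endEncoding();
      enc_ = cbuf_->computeCommandEncoder();
      num_ops_ = 0;
    }
  }

  MTL::CommandBuffer* cbuf_;
  MTL::ComputeCommandEncoder* enc_;
  int num_ops_{0};
};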


@@ -99,7 +99,7 @@ void sdpa_metal(
   constexpr const uint tgroupMemorySize = 32768;
   compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   {
     auto kernel_accum = d.get_kernel(kname_reduce.str());
@@ -114,7 +114,7 @@ void sdpa_metal(
   MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
   MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
-  compute_encoder->dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
+  compute_encoder.dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
   d.get_command_buffer(s.index)->addCompletedHandler(
       [temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
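For context, a call site written against the hypothetical wrapper above would look like the two hunks in this file: per-kernel setup still uses -> (the raw encoder), while dispatches use . so they can be counted and trigger a split. This is only a sketch mirroring the diff; SplittingEncoder, attend_kernel, and reduce_kernel are illustrative names, not MLX identifiers.

// Hypothetical call site mirroring the two hunks above (sketch only).
void encode_sdpa_like(
    SplittingEncoder& compute_encoder,
    MTL::ComputePipelineState* attend_kernel,
    MTL::ComputePipelineState* reduce_kernel,
    MTL::Size grid_dims,
    MTL::Size group_dims,
    MTL::Size grid_dims_reduce,
    MTL::Size group_dims_reduce) {
  // State setup is not counted and uses the raw encoder via operator->().
  compute_encoder->setComputePipelineState(attend_kernel);
  compute_encoder->setThreadgroupMemoryLength(32768, 0);
  // Dispatches go through the wrapper, which may end this encoder and open a
  // new one on the same command buffer after the call.
  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

  compute_encoder->setComputePipelineState(reduce_kernel);
  compute_encoder.dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
}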