mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Split encoders in non-concurrent context with a max ops per encoder (#1085)
* split encoders * fix race
This commit is contained in:
@@ -99,7 +99,7 @@ void sdpa_metal(
|
||||
|
||||
constexpr const uint tgroupMemorySize = 32768;
|
||||
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
{
|
||||
auto kernel_accum = d.get_kernel(kname_reduce.str());
|
||||
@@ -114,7 +114,7 @@ void sdpa_metal(
|
||||
MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
|
||||
MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
|
||||
|
||||
Reference in New Issue
Block a user