Explicit barriers with concurrent dispatch (#977)

This commit is contained in:
Awni Hannun
2024-04-10 21:45:31 -07:00
committed by GitHub
parent 8580d997ff
commit 12d4507ee3
21 changed files with 326 additions and 267 deletions

View File

@@ -71,7 +71,7 @@ void sdpa_metal(
std::string kname_suffix = kname_suffix_tile_size + kname_suffix_nsimdgroups;
kname_partials << kname_suffix;
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname_partials.str());
compute_encoder->setComputePipelineState(kernel);
@@ -87,15 +87,15 @@ void sdpa_metal(
MLXScaledDotProductAttentionParams params{
query_sequence_length, n_q_heads, n_kv_heads, n_tiles, alpha};
set_array_buffer(compute_encoder, q, 0);
set_array_buffer(compute_encoder, k, 1);
set_array_buffer(compute_encoder, v, 2);
compute_encoder.set_input_array(q, 0);
compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2);
compute_encoder->setBytes(&KV_sequence_length, sizeof(KV_sequence_length), 3);
compute_encoder->setBytes(
&params, sizeof(MLXScaledDotProductAttentionParams), 4);
set_array_buffer(compute_encoder, o_partial, 5);
set_array_buffer(compute_encoder, p_lse, 6);
set_array_buffer(compute_encoder, p_rowmaxes, 7);
compute_encoder.set_input_array(o_partial, 5);
compute_encoder.set_input_array(p_lse, 6);
compute_encoder.set_input_array(p_rowmaxes, 7);
constexpr const uint tgroupMemorySize = 32768;
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
@@ -104,12 +104,12 @@ void sdpa_metal(
{
auto kernel_accum = d.get_kernel(kname_reduce.str());
compute_encoder->setComputePipelineState(kernel_accum);
set_array_buffer(compute_encoder, o_partial, 0);
set_array_buffer(compute_encoder, p_lse, 1);
set_array_buffer(compute_encoder, p_rowmaxes, 2);
compute_encoder.set_input_array(o_partial, 0);
compute_encoder.set_input_array(p_lse, 1);
compute_encoder.set_input_array(p_rowmaxes, 2);
compute_encoder->setBytes(
&params, sizeof(MLXScaledDotProductAttentionParams), 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder.set_output_array(out, 4);
MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);