mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Explicit barriers with concurrent dispatch (#977)
This commit is contained in:
@@ -71,7 +71,7 @@ void sdpa_metal(
|
||||
|
||||
std::string kname_suffix = kname_suffix_tile_size + kname_suffix_nsimdgroups;
|
||||
kname_partials << kname_suffix;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname_partials.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -87,15 +87,15 @@ void sdpa_metal(
|
||||
MLXScaledDotProductAttentionParams params{
|
||||
query_sequence_length, n_q_heads, n_kv_heads, n_tiles, alpha};
|
||||
|
||||
set_array_buffer(compute_encoder, q, 0);
|
||||
set_array_buffer(compute_encoder, k, 1);
|
||||
set_array_buffer(compute_encoder, v, 2);
|
||||
compute_encoder.set_input_array(q, 0);
|
||||
compute_encoder.set_input_array(k, 1);
|
||||
compute_encoder.set_input_array(v, 2);
|
||||
compute_encoder->setBytes(&KV_sequence_length, sizeof(KV_sequence_length), 3);
|
||||
compute_encoder->setBytes(
|
||||
&params, sizeof(MLXScaledDotProductAttentionParams), 4);
|
||||
set_array_buffer(compute_encoder, o_partial, 5);
|
||||
set_array_buffer(compute_encoder, p_lse, 6);
|
||||
set_array_buffer(compute_encoder, p_rowmaxes, 7);
|
||||
compute_encoder.set_input_array(o_partial, 5);
|
||||
compute_encoder.set_input_array(p_lse, 6);
|
||||
compute_encoder.set_input_array(p_rowmaxes, 7);
|
||||
|
||||
constexpr const uint tgroupMemorySize = 32768;
|
||||
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
|
||||
@@ -104,12 +104,12 @@ void sdpa_metal(
|
||||
{
|
||||
auto kernel_accum = d.get_kernel(kname_reduce.str());
|
||||
compute_encoder->setComputePipelineState(kernel_accum);
|
||||
set_array_buffer(compute_encoder, o_partial, 0);
|
||||
set_array_buffer(compute_encoder, p_lse, 1);
|
||||
set_array_buffer(compute_encoder, p_rowmaxes, 2);
|
||||
compute_encoder.set_input_array(o_partial, 0);
|
||||
compute_encoder.set_input_array(p_lse, 1);
|
||||
compute_encoder.set_input_array(p_rowmaxes, 2);
|
||||
compute_encoder->setBytes(
|
||||
&params, sizeof(MLXScaledDotProductAttentionParams), 3);
|
||||
set_array_buffer(compute_encoder, out, 4);
|
||||
compute_encoder.set_output_array(out, 4);
|
||||
|
||||
MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
|
||||
MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
|
||||
|
||||
Reference in New Issue
Block a user