From f1951d6cce2c14dd707c768d9234ecb18f36e08a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 4 Nov 2024 10:26:49 -0800 Subject: [PATCH] Use fewer barriers (#1561) * use fewer barriers * comment --- mlx/backend/metal/device.cpp | 24 ++++++++++++++++-------- mlx/backend/metal/device.h | 7 +++++-- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 8074547ca..d7b758e4d 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -136,13 +136,8 @@ void CommandEncoder::set_input_array( int64_t offset /* = 0 */) { all_inputs_.insert(a.buffer().ptr()); auto r_buf = static_cast(const_cast(a.buffer().ptr())); - if (auto it = outputs_.find(r_buf); it != outputs_.end()) { - // Insert a barrier - enc_->memoryBarrier(&r_buf, 1); - - // Remove the output - outputs_.erase(it); - } + needs_barrier_ = + needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end()); auto a_buf = static_cast(a.buffer().ptr()); auto base_offset = a.data() - static_cast(const_cast(a_buf)->contents()); @@ -161,19 +156,32 @@ void CommandEncoder::set_output_array( if (concurrent_) { concurrent_outputs_.insert(buf); } else { - outputs_.insert(buf); + next_outputs_.insert(buf); } } +void CommandEncoder::maybeInsertBarrier() { + if (needs_barrier_) { + enc_->memoryBarrier(MTL::BarrierScopeBuffers); + needs_barrier_ = false; + prev_outputs_ = std::move(next_outputs_); + } else { + prev_outputs_.insert(next_outputs_.begin(), next_outputs_.end()); + } + next_outputs_.clear(); +} + void CommandEncoder::dispatchThreadgroups( MTL::Size grid_dims, MTL::Size group_dims) { + maybeInsertBarrier(); enc_->dispatchThreadgroups(grid_dims, group_dims); } void CommandEncoder::dispatchThreads( MTL::Size grid_dims, MTL::Size group_dims) { + maybeInsertBarrier(); enc_->dispatchThreads(grid_dims, group_dims); } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index fe32cc738..bd366dc47 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -49,7 +49,7 @@ struct CommandEncoder { } ~ConcurrentContext() { enc.concurrent_ = false; - enc.outputs_.insert( + enc.prev_outputs_.insert( enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end()); enc.concurrent_outputs_.clear(); } @@ -66,6 +66,7 @@ struct CommandEncoder { void set_output_array(array& a, int idx, int64_t offset = 0); void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims); void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims); + void maybeInsertBarrier(); ConcurrentContext start_concurrent() { return ConcurrentContext(*this); @@ -84,8 +85,10 @@ struct CommandEncoder { private: MTL::ComputeCommandEncoder* enc_; + bool needs_barrier_{false}; bool concurrent_{false}; - std::unordered_set outputs_; + std::unordered_set prev_outputs_; + std::unordered_set next_outputs_; std::unordered_set concurrent_outputs_; std::unordered_set all_inputs_; std::unordered_set all_outputs_;