Use fewer barriers (#1561)

* use fewer barriers

* comment
This commit is contained in:
Awni Hannun 2024-11-04 10:26:49 -08:00 committed by GitHub
parent 62f297b51d
commit f1951d6cce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 21 additions and 10 deletions

View File

@ -136,13 +136,8 @@ void CommandEncoder::set_input_array(
int64_t offset /* = 0 */) {
all_inputs_.insert(a.buffer().ptr());
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
// Insert a barrier
enc_->memoryBarrier(&r_buf, 1);
// Remove the output
outputs_.erase(it);
}
needs_barrier_ =
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
@ -161,19 +156,32 @@ void CommandEncoder::set_output_array(
if (concurrent_) {
concurrent_outputs_.insert(buf);
} else {
outputs_.insert(buf);
next_outputs_.insert(buf);
}
}
void CommandEncoder::maybeInsertBarrier() {
if (needs_barrier_) {
enc_->memoryBarrier(MTL::BarrierScopeBuffers);
needs_barrier_ = false;
prev_outputs_ = std::move(next_outputs_);
} else {
prev_outputs_.insert(next_outputs_.begin(), next_outputs_.end());
}
next_outputs_.clear();
}
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
enc_->dispatchThreadgroups(grid_dims, group_dims);
}
void CommandEncoder::dispatchThreads(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
enc_->dispatchThreads(grid_dims, group_dims);
}

View File

@ -49,7 +49,7 @@ struct CommandEncoder {
}
~ConcurrentContext() {
enc.concurrent_ = false;
enc.outputs_.insert(
enc.prev_outputs_.insert(
enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
enc.concurrent_outputs_.clear();
}
@ -66,6 +66,7 @@ struct CommandEncoder {
void set_output_array(array& a, int idx, int64_t offset = 0);
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
void maybeInsertBarrier();
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
@ -84,8 +85,10 @@ struct CommandEncoder {
private:
MTL::ComputeCommandEncoder* enc_;
bool needs_barrier_{false};
bool concurrent_{false};
std::unordered_set<MTL::Resource*> outputs_;
std::unordered_set<MTL::Resource*> prev_outputs_;
std::unordered_set<MTL::Resource*> next_outputs_;
std::unordered_set<MTL::Resource*> concurrent_outputs_;
std::unordered_set<const void*> all_inputs_;
std::unordered_set<const void*> all_outputs_;