Use fewer barriers (#1561)

* use fewer barriers

* comment
This commit is contained in:
Awni Hannun
2024-11-04 10:26:49 -08:00
committed by GitHub
parent 62f297b51d
commit f1951d6cce
2 changed files with 21 additions and 10 deletions

View File

@@ -136,13 +136,8 @@ void CommandEncoder::set_input_array(
int64_t offset /* = 0 */) {
all_inputs_.insert(a.buffer().ptr());
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
// Insert a barrier
enc_->memoryBarrier(&r_buf, 1);
// Remove the output
outputs_.erase(it);
}
needs_barrier_ =
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
@@ -161,19 +156,32 @@ void CommandEncoder::set_output_array(
if (concurrent_) {
concurrent_outputs_.insert(buf);
} else {
outputs_.insert(buf);
next_outputs_.insert(buf);
}
}
void CommandEncoder::maybeInsertBarrier() {
if (needs_barrier_) {
enc_->memoryBarrier(MTL::BarrierScopeBuffers);
needs_barrier_ = false;
prev_outputs_ = std::move(next_outputs_);
} else {
prev_outputs_.insert(next_outputs_.begin(), next_outputs_.end());
}
next_outputs_.clear();
}
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
enc_->dispatchThreadgroups(grid_dims, group_dims);
}
void CommandEncoder::dispatchThreads(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
enc_->dispatchThreads(grid_dims, group_dims);
}