This commit is contained in:
Awni Hannun 2024-10-24 19:17:46 -07:00 committed by GitHub
parent 430ffef58a
commit dad1b00b13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 11 deletions

View File

@ -130,10 +130,11 @@ CommandEncoder::~CommandEncoder() {
enc_->release();
}
void CommandEncoder::set_array(
void CommandEncoder::set_input_array(
const array& a,
int idx,
int64_t offset /* = 0 */) {
all_inputs_.insert(a.buffer().ptr());
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
// Insert a barrier
@ -149,20 +150,12 @@ void CommandEncoder::set_array(
enc_->setBuffer(a_buf, base_offset, idx);
}
void CommandEncoder::set_input_array(
const array& a,
int idx,
int64_t offset /* = 0 */) {
all_inputs_.insert(a.buffer().ptr());
set_array(a, idx, offset);
}
void CommandEncoder::set_output_array(
array& a,
int idx,
int64_t offset /* = 0 */) {
// Add barriers before adding the output to the output set
set_array(a, idx, offset);
set_input_array(a, idx, offset);
all_outputs_.insert(a.buffer().ptr());
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent_) {

View File

@ -83,7 +83,6 @@ struct CommandEncoder {
};
private:
void set_array(const array& a, int idx, int64_t offset);
MTL::ComputeCommandEncoder* enc_;
bool concurrent_{false};
std::unordered_set<MTL::Resource*> outputs_;