This commit is contained in:
Awni Hannun 2024-10-24 19:17:46 -07:00 committed by GitHub
parent 430ffef58a
commit dad1b00b13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 11 deletions

View File

@ -130,10 +130,11 @@ CommandEncoder::~CommandEncoder() {
enc_->release(); enc_->release();
} }
void CommandEncoder::set_array( void CommandEncoder::set_input_array(
const array& a, const array& a,
int idx, int idx,
int64_t offset /* = 0 */) { int64_t offset /* = 0 */) {
all_inputs_.insert(a.buffer().ptr());
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr())); auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs_.find(r_buf); it != outputs_.end()) { if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
// Insert a barrier // Insert a barrier
@ -149,20 +150,12 @@ void CommandEncoder::set_array(
enc_->setBuffer(a_buf, base_offset, idx); enc_->setBuffer(a_buf, base_offset, idx);
} }
void CommandEncoder::set_input_array(
const array& a,
int idx,
int64_t offset /* = 0 */) {
all_inputs_.insert(a.buffer().ptr());
set_array(a, idx, offset);
}
void CommandEncoder::set_output_array( void CommandEncoder::set_output_array(
array& a, array& a,
int idx, int idx,
int64_t offset /* = 0 */) { int64_t offset /* = 0 */) {
// Add barriers before adding the output to the output set // Add barriers before adding the output to the output set
set_array(a, idx, offset); set_input_array(a, idx, offset);
all_outputs_.insert(a.buffer().ptr()); all_outputs_.insert(a.buffer().ptr());
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr()); auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent_) { if (concurrent_) {

View File

@ -83,7 +83,6 @@ struct CommandEncoder {
}; };
private: private:
void set_array(const array& a, int idx, int64_t offset);
MTL::ComputeCommandEncoder* enc_; MTL::ComputeCommandEncoder* enc_;
bool concurrent_{false}; bool concurrent_{false};
std::unordered_set<MTL::Resource*> outputs_; std::unordered_set<MTL::Resource*> outputs_;