diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index cc0694ca8..37254a8a0 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -130,10 +130,11 @@ CommandEncoder::~CommandEncoder() { enc_->release(); } -void CommandEncoder::set_array( +void CommandEncoder::set_input_array( const array& a, int idx, int64_t offset /* = 0 */) { + all_inputs_.insert(a.buffer().ptr()); auto r_buf = static_cast(const_cast(a.buffer().ptr())); if (auto it = outputs_.find(r_buf); it != outputs_.end()) { // Insert a barrier @@ -149,20 +150,12 @@ void CommandEncoder::set_array( enc_->setBuffer(a_buf, base_offset, idx); } -void CommandEncoder::set_input_array( - const array& a, - int idx, - int64_t offset /* = 0 */) { - all_inputs_.insert(a.buffer().ptr()); - set_array(a, idx, offset); -} - void CommandEncoder::set_output_array( array& a, int idx, int64_t offset /* = 0 */) { // Add barriers before adding the output to the output set - set_array(a, idx, offset); + set_input_array(a, idx, offset); all_outputs_.insert(a.buffer().ptr()); auto buf = static_cast(a.buffer().ptr()); if (concurrent_) { diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index d15a4aaf8..a3b613d68 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -83,7 +83,6 @@ struct CommandEncoder { }; private: - void set_array(const array& a, int idx, int64_t offset); MTL::ComputeCommandEncoder* enc_; bool concurrent_{false}; std::unordered_set outputs_;