From dad1b00b139743f35c593c9f2438f9797daec82d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 24 Oct 2024 19:17:46 -0700 Subject: [PATCH] fix (#1523) --- mlx/backend/metal/device.cpp | 13 +++---------- mlx/backend/metal/device.h | 1 - 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index cc0694ca8..37254a8a0 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -130,10 +130,11 @@ CommandEncoder::~CommandEncoder() { enc_->release(); } -void CommandEncoder::set_array( +void CommandEncoder::set_input_array( const array& a, int idx, int64_t offset /* = 0 */) { + all_inputs_.insert(a.buffer().ptr()); auto r_buf = static_cast(const_cast(a.buffer().ptr())); if (auto it = outputs_.find(r_buf); it != outputs_.end()) { // Insert a barrier @@ -149,20 +150,12 @@ void CommandEncoder::set_array( enc_->setBuffer(a_buf, base_offset, idx); } -void CommandEncoder::set_input_array( - const array& a, - int idx, - int64_t offset /* = 0 */) { - all_inputs_.insert(a.buffer().ptr()); - set_array(a, idx, offset); -} - void CommandEncoder::set_output_array( array& a, int idx, int64_t offset /* = 0 */) { // Add barriers before adding the output to the output set - set_array(a, idx, offset); + set_input_array(a, idx, offset); all_outputs_.insert(a.buffer().ptr()); auto buf = static_cast(a.buffer().ptr()); if (concurrent_) { diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index d15a4aaf8..a3b613d68 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -83,7 +83,6 @@ struct CommandEncoder { }; private: - void set_array(const array& a, int idx, int64_t offset); MTL::ComputeCommandEncoder* enc_; bool concurrent_{false}; std::unordered_set outputs_;