From fc6494cac7763d83c633c888fed9b10d9ffc910a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 5 Nov 2024 14:42:51 -0800 Subject: [PATCH] register in task submission --- mlx/backend/metal/device.cpp | 51 +++++++++++++++++++----------------- mlx/backend/metal/device.h | 25 ++++++++---------- mlx/backend/metal/metal.cpp | 12 +++++++++ 3 files changed, 50 insertions(+), 38 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index d7b758e4d..22fe6b6e2 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -134,7 +134,6 @@ void CommandEncoder::set_input_array( const array& a, int idx, int64_t offset /* = 0 */) { - all_inputs_.insert(a.buffer().ptr()); auto r_buf = static_cast(const_cast(a.buffer().ptr())); needs_barrier_ = needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end()); @@ -151,7 +150,6 @@ void CommandEncoder::set_output_array( int64_t offset /* = 0 */) { // Add barriers before adding the output to the output set set_input_array(a, idx, offset); - all_outputs_.insert(a.buffer().ptr()); auto buf = static_cast(a.buffer().ptr()); if (concurrent_) { concurrent_outputs_.insert(buf); @@ -185,6 +183,24 @@ void CommandEncoder::dispatchThreads( enc_->dispatchThreads(grid_dims, group_dims); } +void DeviceStream::register_inputs(const std::vector& arrays) { + std::lock_guard lk(fence_mtx); + for (auto& a : arrays) { + auto buf = a.buffer().ptr(); + if (auto it = outputs.find(buf); it != outputs.end()) { + waiting_on.insert(it->second); + } + } +} + +void DeviceStream::register_outputs(const std::vector& arrays) { + for (auto& a : arrays) { + if (a.data() != nullptr) { + all_outputs.insert(a.buffer().ptr()); + } + } +} + Device::Device() { auto pool = new_scoped_memory_pool(); device_ = load_device(); @@ -281,38 +297,22 @@ void Device::end_encoding(int index) { // boundaries. These can be removed early from the encoders inputs and // outputs since they don't need synchronization. auto& enc = *stream.encoder; - // Remove temporaries from inputs and outputs - for (auto& t : stream.temporaries) { - if (t.data() != nullptr) { - enc.outputs().erase(t.buffer().ptr()); - enc.inputs().erase(t.buffer().ptr()); - } - } - // Keep references to the fences we waited on and put them - // in the completion handler so they are not prematurely released - std::unordered_set> waiting_on; + for (auto& f : stream.waiting_on) { + enc->waitForFence(f->fence); + } { std::lock_guard lk(stream.fence_mtx); - for (auto in : enc.inputs()) { - if (auto it = stream.outputs.find(in); it != stream.outputs.end()) { - // If we've already waited on a fence, don't wait on it again. - if (waiting_on.find(it->second) == waiting_on.end()) { - enc->waitForFence(it->second->fence); - waiting_on.insert(it->second); - } - } - } - for (auto out : enc.outputs()) { + for (auto out : stream.all_outputs) { stream.outputs[out] = stream.fence; } } enc->updateFence(stream.fence->fence); stream.buffer->addCompletedHandler( [&stream, - waiting_on = std::move(waiting_on), + waiting_on = std::move(stream.waiting_on), fence = std::move(stream.fence), - outputs = std::move(enc.outputs()), + outputs = std::move(stream.all_outputs), temporaries = std::move(stream.temporaries)](MTL::CommandBuffer*) mutable { temporaries.clear(); @@ -325,6 +325,9 @@ void Device::end_encoding(int index) { } } }); + } else { + stream.all_outputs.clear(); + stream.waiting_on.clear(); } stream.encoder = nullptr; } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index bd366dc47..d2287c679 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -73,16 +73,6 @@ struct CommandEncoder { } ~CommandEncoder(); - // Inputs to all kernels in the encoder including temporaries - std::unordered_set& inputs() { - return all_inputs_; - }; - - // Outputs of all kernels in the encoder including temporaries - std::unordered_set outputs() { - return all_outputs_; - }; - private: MTL::ComputeCommandEncoder* enc_; bool needs_barrier_{false}; @@ -90,8 +80,6 @@ struct CommandEncoder { std::unordered_set prev_outputs_; std::unordered_set next_outputs_; std::unordered_set concurrent_outputs_; - std::unordered_set all_inputs_; - std::unordered_set all_outputs_; }; struct Fence { @@ -121,11 +109,15 @@ struct DeviceStream { MTL::CommandBuffer* buffer{nullptr}; int buffer_ops{0}; - // The command encoder, fence, and temporaries are updated between command - // encoders + void register_inputs(const std::vector& inputs); + void register_outputs(const std::vector& inputs); + + // The following variables are all reset between command encoders std::unique_ptr encoder{nullptr}; std::shared_ptr fence; std::vector temporaries; + std::unordered_set> waiting_on; + std::unordered_set all_outputs; }; class Device { @@ -190,10 +182,15 @@ class Device { void set_residency_set(const MTL::ResidencySet* residency_set); + DeviceStream& get_stream(int index) { + return stream_map_.find(index)->second; + } + private: DeviceStream& get_stream_(int index) { return stream_map_.find(index)->second; } + MTL::Library* get_library_cache_(const std::string& name); MTL::Library* get_library_(const std::string& name); diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 2a5e6334e..37ff56953 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -60,8 +60,20 @@ std::function make_task(array arr, bool signal) { inputs = arr.inputs(); } + // if (dstream.active_without_fence()) { + // for (auto& in : arr.inputs()) { + // // If any input needs a fence + // if (command_encoder.needs_fence(in)) { + // end_encoding(); + // break; + // } + // } + // } + auto& dstream = d.get_stream(s.index); + dstream.register_inputs(arr.inputs()); debug_set_primitive_buffer_label(command_buffer, arr.primitive()); arr.primitive().eval_gpu(arr.inputs(), outputs); + dstream.register_outputs(outputs); } std::vector> buffers; for (auto& in : arr.inputs()) {