From 7aea5b18955dde18fb9c702f7facddbbda7c9254 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 13 Feb 2025 19:18:22 -0800 Subject: [PATCH] Allow dynamic ops per buffer based on dispatches and memory (#1864) * Allow dynamic ops per buffer based on dispatches and memory * add initial arch values --- mlx/backend/metal/device.cpp | 52 +++++++++++++++++++++++++++++------- mlx/backend/metal/device.h | 14 ++++++---- mlx/backend/metal/matmul.cpp | 2 +- mlx/backend/metal/metal.cpp | 4 +-- mlx/utils.h | 11 ++++++-- 5 files changed, 62 insertions(+), 21 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 9379733b6..c6e78d263 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -13,6 +13,7 @@ #include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/utils.h" +#include "mlx/utils.h" namespace mlx::core::metal { @@ -124,8 +125,8 @@ MTL::Library* load_library( } // namespace -CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) { - enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent); +CommandEncoder::CommandEncoder(DeviceStream& stream) : stream_(stream) { + enc_ = stream_.buffer->computeCommandEncoder(MTL::DispatchTypeConcurrent); enc_->retain(); } @@ -145,7 +146,9 @@ void CommandEncoder::set_input_array( const array& a, int idx, int64_t offset /* = 0 */) { - all_inputs_.insert(a.buffer().ptr()); + if (all_inputs_.insert(a.buffer().ptr()).second) { + stream_.buffer_sizes += a.data_size(); + } auto r_buf = static_cast(const_cast(a.buffer().ptr())); needs_barrier_ = needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end()); @@ -190,6 +193,7 @@ void CommandEncoder::dispatch_threadgroups( MTL::Size grid_dims, MTL::Size group_dims) { maybeInsertBarrier(); + stream_.buffer_ops++; enc_->dispatchThreadgroups(grid_dims, group_dims); } @@ -197,6 +201,7 @@ void CommandEncoder::dispatch_threads( MTL::Size grid_dims, MTL::Size group_dims) { maybeInsertBarrier(); + stream_.buffer_ops++; enc_->dispatchThreads(grid_dims, group_dims); } @@ -209,6 +214,31 @@ Device::Device() { device_ = load_device(); library_map_ = {{"mlx", load_library(device_)}}; arch_ = std::string(device_->architecture()->name()->utf8String()); + auto arch = arch_.back(); + switch (arch) { + case 'p': // phone + max_ops_per_buffer_ = 20; + max_mb_per_buffer_ = 40; + break; + case 'g': // base, pro + max_ops_per_buffer_ = 40; + max_mb_per_buffer_ = 40; + break; + case 's': // max + max_ops_per_buffer_ = 50; + max_mb_per_buffer_ = 50; + break; + case 'd': // ultra + max_ops_per_buffer_ = 50; + max_mb_per_buffer_ = 50; + break; + default: // default to medium + max_ops_per_buffer_ = 40; + max_mb_per_buffer_ = 40; + break; + } + max_ops_per_buffer_ = env::max_ops_per_buffer(max_ops_per_buffer_); + max_mb_per_buffer_ = env::max_mb_per_buffer(max_mb_per_buffer_); } Device::~Device() { @@ -239,12 +269,13 @@ void Device::new_queue(int index) { } } -int Device::get_command_buffer_ops(int index) { - return get_stream_(index).buffer_ops; -} - -void Device::increment_command_buffer_ops(int index) { - get_stream_(index).buffer_ops++; +bool Device::command_buffer_needs_commit(int index) { + auto& stream = get_stream_(index); + if (stream.buffer_ops > max_ops_per_buffer_ || + (stream.buffer_sizes >> 20) > max_mb_per_buffer_) { + return true; + } + return false; } MTL::CommandBuffer* Device::get_command_buffer(int index) { @@ -267,6 +298,7 @@ void Device::commit_command_buffer(int index) { stream.buffer->release(); stream.buffer = nullptr; stream.buffer_ops = 0; + stream.buffer_sizes = 0; } void Device::add_temporary(array arr, int index) { @@ -351,7 +383,7 @@ void Device::end_encoding(int index) { CommandEncoder& Device::get_command_encoder(int index) { auto& stream = get_stream_(index); if (stream.encoder == nullptr) { - stream.encoder = std::make_unique(stream.buffer); + stream.encoder = std::make_unique(stream); stream.fence = std::make_shared(device_->newFence()); } return *stream.encoder; diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 1c58cdd3d..81d114232 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -38,8 +38,10 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) { using MTLFCList = std::vector>; +struct DeviceStream; + struct CommandEncoder { - CommandEncoder(MTL::CommandBuffer* cbuf); + explicit CommandEncoder(DeviceStream& stream); CommandEncoder(const CommandEncoder&) = delete; CommandEncoder& operator=(const CommandEncoder&) = delete; @@ -115,6 +117,7 @@ struct CommandEncoder { void barrier(); private: + DeviceStream& stream_; MTL::ComputeCommandEncoder* enc_; bool needs_barrier_{false}; bool concurrent_{false}; @@ -147,10 +150,10 @@ struct DeviceStream { // Used to allow thread-safe access to the outputs map std::mutex fence_mtx; - // The buffer and buffer op count are updated - // between command buffers + // Data updated between command buffers MTL::CommandBuffer* buffer{nullptr}; int buffer_ops{0}; + size_t buffer_sizes{0}; // The command encoder, fence, and temporaries are updated between command // encoders @@ -176,8 +179,7 @@ class Device { void new_queue(int index); MTL::CommandBuffer* get_command_buffer(int index); - int get_command_buffer_ops(int index); - void increment_command_buffer_ops(int index); + bool command_buffer_needs_commit(int index); void commit_command_buffer(int index); CommandEncoder& get_command_encoder(int index); void end_encoding(int index); @@ -267,6 +269,8 @@ class Device { std::unordered_map library_map_; const MTL::ResidencySet* residency_set_{nullptr}; std::string arch_; + int max_ops_per_buffer_; + int max_mb_per_buffer_; }; Device& device(mlx::core::Device); diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 707e43fa1..4d3cf21ee 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -109,7 +109,7 @@ std::tuple check_transpose( /////////////////////////////////////////////////////////////////////////////// #define GEMM_TPARAM_MACRO(devc) \ - if (devc == 'g') { /* Small device */ \ + if (devc == 'g' || devc == 'p') { /* Small device */ \ if (!transpose_a && transpose_b) { /* nt */ \ bm = 64; \ bn = 32; \ diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index a8bc716c1..4933a0855 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -29,7 +29,6 @@ std::function make_task(array arr, bool signal) { auto s = arr.primitive().stream(); auto& d = metal::device(s.device); auto command_buffer = d.get_command_buffer(s.index); - d.increment_command_buffer_ops(s.index); for (auto& input : arr.inputs()) { if (input.event().valid() && @@ -68,8 +67,7 @@ std::function make_task(array arr, bool signal) { out.set_status(array::Status::evaluated); } - if (signal || - d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) { + if (signal || d.command_buffer_needs_commit(s.index)) { if (signal) { encode_signal(arr.event()); } diff --git a/mlx/utils.h b/mlx/utils.h index df6bc0ec6..11ac37be8 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -122,11 +122,18 @@ inline int bfs_max_width() { return bfs_max_width_; } -inline int max_ops_per_buffer() { - static int max_ops_per_buffer_ = get_var("MLX_MAX_OPS_PER_BUFFER", 10); +inline int max_ops_per_buffer(int default_value) { + static int max_ops_per_buffer_ = + get_var("MLX_MAX_OPS_PER_BUFFER", default_value); return max_ops_per_buffer_; } +inline int max_mb_per_buffer(int default_value) { + static int max_mb_per_buffer_ = + get_var("MLX_MAX_MB_PER_BUFFER", default_value); + return max_mb_per_buffer_; +} + inline bool metal_fast_synch() { static bool metal_fast_synch = get_var("MLX_METAL_FAST_SYNCH", 0); return metal_fast_synch;