Allow dynamic ops per buffer based on dispatches and memory (#1864)

* Allow dynamic ops per buffer based on dispatches and memory

* add initial arch values
Awni Hannun 2025-02-13 19:18:22 -08:00 committed by GitHub
parent 9733e16496
commit 7aea5b1895
5 changed files with 62 additions and 21 deletions
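
In outline: instead of committing a Metal command buffer after a fixed number of ops (previously MLX_MAX_OPS_PER_BUFFER, default 10), the limit is now derived from the device architecture and covers both the number of kernel dispatches and the total size of unique input buffers. A rough, self-contained sketch of the heuristic this diff implements (BufferLimits, limits_for_arch, and needs_commit are hypothetical names used here for illustration; the real logic lives in Device and CommandEncoder below):

#include <cstddef>

// Hypothetical condensation of the per-architecture defaults and the commit
// check added in this change; values mirror the switch in Device::Device() below.
struct BufferLimits {
  int max_ops; // dispatches per command buffer
  int max_mb;  // MB of unique input buffers per command buffer
};

inline BufferLimits limits_for_arch(char arch_suffix) {
  switch (arch_suffix) {
    case 'p':          // phone
      return {20, 40};
    case 's':          // max
    case 'd':          // ultra
      return {50, 50};
    case 'g':          // base, pro
    default:           // default to medium
      return {40, 40};
  }
}

// Commit when either threshold is exceeded; (buffer_bytes >> 20) converts
// bytes to MB, as in Device::command_buffer_needs_commit below.
inline bool needs_commit(int buffer_ops, size_t buffer_bytes, const BufferLimits& lim) {
  return buffer_ops > lim.max_ops ||
      (buffer_bytes >> 20) > static_cast<size_t>(lim.max_mb);
}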

View File

@@ -13,6 +13,7 @@
 #include "mlx/backend/metal/metal.h"
 #include "mlx/backend/metal/metal_impl.h"
 #include "mlx/backend/metal/utils.h"
+#include "mlx/utils.h"

 namespace mlx::core::metal {
@@ -124,8 +125,8 @@ MTL::Library* load_library(
 } // namespace

-CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) {
-  enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
+CommandEncoder::CommandEncoder(DeviceStream& stream) : stream_(stream) {
+  enc_ = stream_.buffer->computeCommandEncoder(MTL::DispatchTypeConcurrent);
   enc_->retain();
 }
@@ -145,7 +146,9 @@ void CommandEncoder::set_input_array(
     const array& a,
     int idx,
     int64_t offset /* = 0 */) {
-  all_inputs_.insert(a.buffer().ptr());
+  if (all_inputs_.insert(a.buffer().ptr()).second) {
+    stream_.buffer_sizes += a.data_size();
+  }
   auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
   needs_barrier_ =
       needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
@@ -190,6 +193,7 @@ void CommandEncoder::dispatch_threadgroups(
     MTL::Size grid_dims,
     MTL::Size group_dims) {
   maybeInsertBarrier();
+  stream_.buffer_ops++;
   enc_->dispatchThreadgroups(grid_dims, group_dims);
 }
@@ -197,6 +201,7 @@ void CommandEncoder::dispatch_threads(
     MTL::Size grid_dims,
     MTL::Size group_dims) {
   maybeInsertBarrier();
+  stream_.buffer_ops++;
   enc_->dispatchThreads(grid_dims, group_dims);
 }
@@ -209,6 +214,31 @@ Device::Device() {
   device_ = load_device();
   library_map_ = {{"mlx", load_library(device_)}};
   arch_ = std::string(device_->architecture()->name()->utf8String());
+  auto arch = arch_.back();
+  switch (arch) {
+    case 'p': // phone
+      max_ops_per_buffer_ = 20;
+      max_mb_per_buffer_ = 40;
+      break;
+    case 'g': // base, pro
+      max_ops_per_buffer_ = 40;
+      max_mb_per_buffer_ = 40;
+      break;
+    case 's': // max
+      max_ops_per_buffer_ = 50;
+      max_mb_per_buffer_ = 50;
+      break;
+    case 'd': // ultra
+      max_ops_per_buffer_ = 50;
+      max_mb_per_buffer_ = 50;
+      break;
+    default: // default to medium
+      max_ops_per_buffer_ = 40;
+      max_mb_per_buffer_ = 40;
+      break;
+  }
+  max_ops_per_buffer_ = env::max_ops_per_buffer(max_ops_per_buffer_);
+  max_mb_per_buffer_ = env::max_mb_per_buffer(max_mb_per_buffer_);
 }

 Device::~Device() {
@@ -239,12 +269,13 @@ void Device::new_queue(int index) {
   }
 }

-int Device::get_command_buffer_ops(int index) {
-  return get_stream_(index).buffer_ops;
-}
-
-void Device::increment_command_buffer_ops(int index) {
-  get_stream_(index).buffer_ops++;
+bool Device::command_buffer_needs_commit(int index) {
+  auto& stream = get_stream_(index);
+  if (stream.buffer_ops > max_ops_per_buffer_ ||
+      (stream.buffer_sizes >> 20) > max_mb_per_buffer_) {
+    return true;
+  }
+  return false;
 }

 MTL::CommandBuffer* Device::get_command_buffer(int index) {
@@ -267,6 +298,7 @@ void Device::commit_command_buffer(int index) {
   stream.buffer->release();
   stream.buffer = nullptr;
   stream.buffer_ops = 0;
+  stream.buffer_sizes = 0;
 }

 void Device::add_temporary(array arr, int index) {
@@ -351,7 +383,7 @@ void Device::end_encoding(int index) {
 CommandEncoder& Device::get_command_encoder(int index) {
   auto& stream = get_stream_(index);
   if (stream.encoder == nullptr) {
-    stream.encoder = std::make_unique<CommandEncoder>(stream.buffer);
+    stream.encoder = std::make_unique<CommandEncoder>(stream);
     stream.fence = std::make_shared<Fence>(device_->newFence());
   }
   return *stream.encoder;
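
A note on the accounting in set_input_array above: a buffer only contributes to stream_.buffer_sizes the first time it is seen, because insert(...).second is true only for newly inserted keys, so a buffer reused by several dispatches is counted once. A standalone sketch of that pattern (FakeInput, the byte counts, and main() are stand-ins, not MLX code):

#include <cstddef>
#include <unordered_set>

struct FakeInput {
  const void* ptr;
  size_t bytes;
};

int main() {
  std::unordered_set<const void*> all_inputs;
  size_t buffer_sizes = 0;

  int x = 0, y = 0;
  // The same buffer (&x) appears twice, as it would when one array feeds
  // multiple kernels in the same command buffer.
  FakeInput inputs[] = {{&x, 4 << 20}, {&y, 8 << 20}, {&x, 4 << 20}};

  for (const auto& in : inputs) {
    // Only newly seen buffers add to the running byte total.
    if (all_inputs.insert(in.ptr).second) {
      buffer_sizes += in.bytes;
    }
  }
  // 4 MB + 8 MB = 12 MB; (buffer_sizes >> 20) is the same bytes-to-MB
  // conversion used in Device::command_buffer_needs_commit.
  return (buffer_sizes >> 20) == 12 ? 0 : 1;
}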

View File

@@ -38,8 +38,10 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
 using MTLFCList =
     std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;

+struct DeviceStream;
+
 struct CommandEncoder {
-  CommandEncoder(MTL::CommandBuffer* cbuf);
+  explicit CommandEncoder(DeviceStream& stream);
   CommandEncoder(const CommandEncoder&) = delete;
   CommandEncoder& operator=(const CommandEncoder&) = delete;
@@ -115,6 +117,7 @@ struct CommandEncoder {
   void barrier();

 private:
+  DeviceStream& stream_;
   MTL::ComputeCommandEncoder* enc_;
   bool needs_barrier_{false};
   bool concurrent_{false};
@@ -147,10 +150,10 @@ struct DeviceStream {
   // Used to allow thread-safe access to the outputs map
   std::mutex fence_mtx;

-  // The buffer and buffer op count are updated
-  // between command buffers
+  // Data updated between command buffers
   MTL::CommandBuffer* buffer{nullptr};
   int buffer_ops{0};
+  size_t buffer_sizes{0};

   // The command encoder, fence, and temporaries are updated between command
   // encoders
@@ -176,8 +179,7 @@ class Device {
   void new_queue(int index);
   MTL::CommandBuffer* get_command_buffer(int index);
-  int get_command_buffer_ops(int index);
-  void increment_command_buffer_ops(int index);
+  bool command_buffer_needs_commit(int index);
   void commit_command_buffer(int index);
   CommandEncoder& get_command_encoder(int index);
   void end_encoding(int index);
@@ -267,6 +269,8 @@ class Device {
   std::unordered_map<std::string, MTL::Library*> library_map_;
   const MTL::ResidencySet* residency_set_{nullptr};
   std::string arch_;
+  int max_ops_per_buffer_;
+  int max_mb_per_buffer_;
 };

 Device& device(mlx::core::Device);

View File

@@ -109,7 +109,7 @@ std::tuple<bool, int64_t, array> check_transpose(
 ///////////////////////////////////////////////////////////////////////////////

 #define GEMM_TPARAM_MACRO(devc) \
-  if (devc == 'g') { /* Small device */ \
+  if (devc == 'g' || devc == 'p') { /* Small device */ \
     if (!transpose_a && transpose_b) { /* nt */ \
       bm = 64; \
       bn = 32; \

View File

@@ -29,7 +29,6 @@ std::function<void()> make_task(array arr, bool signal) {
   auto s = arr.primitive().stream();
   auto& d = metal::device(s.device);
   auto command_buffer = d.get_command_buffer(s.index);
-  d.increment_command_buffer_ops(s.index);

   for (auto& input : arr.inputs()) {
     if (input.event().valid() &&
@@ -68,8 +67,7 @@ std::function<void()> make_task(array arr, bool signal) {
       out.set_status(array::Status::evaluated);
     }

-    if (signal ||
-        d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) {
+    if (signal || d.command_buffer_needs_commit(s.index)) {
       if (signal) {
         encode_signal(arr.event());
       }

View File

@@ -122,11 +122,18 @@ inline int bfs_max_width() {
   return bfs_max_width_;
 }

-inline int max_ops_per_buffer() {
-  static int max_ops_per_buffer_ = get_var("MLX_MAX_OPS_PER_BUFFER", 10);
+inline int max_ops_per_buffer(int default_value) {
+  static int max_ops_per_buffer_ =
+      get_var("MLX_MAX_OPS_PER_BUFFER", default_value);
   return max_ops_per_buffer_;
 }

+inline int max_mb_per_buffer(int default_value) {
+  static int max_mb_per_buffer_ =
+      get_var("MLX_MAX_MB_PER_BUFFER", default_value);
+  return max_mb_per_buffer_;
+}
+
 inline bool metal_fast_synch() {
   static bool metal_fast_synch = get_var("MLX_METAL_FAST_SYNCH", 0);
   return metal_fast_synch;
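
The two overrides above cache their value in a function-local static, so MLX_MAX_OPS_PER_BUFFER and MLX_MAX_MB_PER_BUFFER are read once, at the first call, which with this change happens in the Device constructor; they therefore need to be set in the environment before the Metal device is first created. get_var itself is not part of this diff; assuming it parses the variable as an integer and falls back to the supplied default, a minimal sketch could be:

#include <cstdlib>

// Hypothetical stand-in for get_var; the real helper is defined elsewhere
// in mlx and is not shown in this diff.
inline int get_var_sketch(const char* name, int default_value) {
  if (const char* value = std::getenv(name)) {
    return std::atoi(value); // non-numeric input yields 0
  }
  return default_value;
}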