mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Allow dynamic ops per buffer based on dispatches and memory (#1864)
* Allow dynamic ops per buffer based on dispatches and memory
* Add initial arch values
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
@@ -124,8 +125,8 @@ MTL::Library* load_library(
|
||||
|
||||
} // namespace
|
||||
|
||||
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) {
|
||||
enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
CommandEncoder::CommandEncoder(DeviceStream& stream) : stream_(stream) {
|
||||
enc_ = stream_.buffer->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
enc_->retain();
|
||||
}
|
||||
|
||||
@@ -145,7 +146,9 @@ void CommandEncoder::set_input_array(
|
||||
const array& a,
|
||||
int idx,
|
||||
int64_t offset /* = 0 */) {
|
||||
all_inputs_.insert(a.buffer().ptr());
|
||||
if (all_inputs_.insert(a.buffer().ptr()).second) {
|
||||
stream_.buffer_sizes += a.data_size();
|
||||
}
|
||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
needs_barrier_ =
|
||||
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
|
||||
@@ -190,6 +193,7 @@ void CommandEncoder::dispatch_threadgroups(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
maybeInsertBarrier();
|
||||
stream_.buffer_ops++;
|
||||
enc_->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
@@ -197,6 +201,7 @@ void CommandEncoder::dispatch_threads(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
maybeInsertBarrier();
|
||||
stream_.buffer_ops++;
|
||||
enc_->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
@@ -209,6 +214,31 @@ Device::Device() {
|
||||
device_ = load_device();
|
||||
library_map_ = {{"mlx", load_library(device_)}};
|
||||
arch_ = std::string(device_->architecture()->name()->utf8String());
|
||||
auto arch = arch_.back();
|
||||
switch (arch) {
|
||||
case 'p': // phone
|
||||
max_ops_per_buffer_ = 20;
|
||||
max_mb_per_buffer_ = 40;
|
||||
break;
|
||||
case 'g': // base, pro
|
||||
max_ops_per_buffer_ = 40;
|
||||
max_mb_per_buffer_ = 40;
|
||||
break;
|
||||
case 's': // max
|
||||
max_ops_per_buffer_ = 50;
|
||||
max_mb_per_buffer_ = 50;
|
||||
break;
|
||||
case 'd': // ultra
|
||||
max_ops_per_buffer_ = 50;
|
||||
max_mb_per_buffer_ = 50;
|
||||
break;
|
||||
default: // default to medium
|
||||
max_ops_per_buffer_ = 40;
|
||||
max_mb_per_buffer_ = 40;
|
||||
break;
|
||||
}
|
||||
max_ops_per_buffer_ = env::max_ops_per_buffer(max_ops_per_buffer_);
|
||||
max_mb_per_buffer_ = env::max_mb_per_buffer(max_mb_per_buffer_);
|
||||
}
|
||||
|
||||
Device::~Device() {
|
||||
@@ -239,12 +269,13 @@ void Device::new_queue(int index) {
|
||||
}
|
||||
}
|
||||
|
||||
int Device::get_command_buffer_ops(int index) {
|
||||
return get_stream_(index).buffer_ops;
|
||||
}
|
||||
|
||||
void Device::increment_command_buffer_ops(int index) {
|
||||
get_stream_(index).buffer_ops++;
|
||||
bool Device::command_buffer_needs_commit(int index) {
|
||||
auto& stream = get_stream_(index);
|
||||
if (stream.buffer_ops > max_ops_per_buffer_ ||
|
||||
(stream.buffer_sizes >> 20) > max_mb_per_buffer_) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
||||
@@ -267,6 +298,7 @@ void Device::commit_command_buffer(int index) {
|
||||
stream.buffer->release();
|
||||
stream.buffer = nullptr;
|
||||
stream.buffer_ops = 0;
|
||||
stream.buffer_sizes = 0;
|
||||
}
|
||||
|
||||
void Device::add_temporary(array arr, int index) {
|
||||
@@ -351,7 +383,7 @@ void Device::end_encoding(int index) {
|
||||
CommandEncoder& Device::get_command_encoder(int index) {
|
||||
auto& stream = get_stream_(index);
|
||||
if (stream.encoder == nullptr) {
|
||||
stream.encoder = std::make_unique<CommandEncoder>(stream.buffer);
|
||||
stream.encoder = std::make_unique<CommandEncoder>(stream);
|
||||
stream.fence = std::make_shared<Fence>(device_->newFence());
|
||||
}
|
||||
return *stream.encoder;
|
||||
|
||||
Reference in New Issue
Block a user