mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Allow dynamic ops per buffer based on dispatches and memory (#1864)
* Allow dynamic ops per buffer based on dispatches and memory * add initial arch values
This commit is contained in:
parent
9733e16496
commit
7aea5b1895
@ -13,6 +13,7 @@
|
|||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include "mlx/backend/metal/metal_impl.h"
|
#include "mlx/backend/metal/metal_impl.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core::metal {
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
@ -124,8 +125,8 @@ MTL::Library* load_library(
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) {
|
CommandEncoder::CommandEncoder(DeviceStream& stream) : stream_(stream) {
|
||||||
enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
enc_ = stream_.buffer->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||||
enc_->retain();
|
enc_->retain();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,7 +146,9 @@ void CommandEncoder::set_input_array(
|
|||||||
const array& a,
|
const array& a,
|
||||||
int idx,
|
int idx,
|
||||||
int64_t offset /* = 0 */) {
|
int64_t offset /* = 0 */) {
|
||||||
all_inputs_.insert(a.buffer().ptr());
|
if (all_inputs_.insert(a.buffer().ptr()).second) {
|
||||||
|
stream_.buffer_sizes += a.data_size();
|
||||||
|
}
|
||||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||||
needs_barrier_ =
|
needs_barrier_ =
|
||||||
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
|
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
|
||||||
@ -190,6 +193,7 @@ void CommandEncoder::dispatch_threadgroups(
|
|||||||
MTL::Size grid_dims,
|
MTL::Size grid_dims,
|
||||||
MTL::Size group_dims) {
|
MTL::Size group_dims) {
|
||||||
maybeInsertBarrier();
|
maybeInsertBarrier();
|
||||||
|
stream_.buffer_ops++;
|
||||||
enc_->dispatchThreadgroups(grid_dims, group_dims);
|
enc_->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -197,6 +201,7 @@ void CommandEncoder::dispatch_threads(
|
|||||||
MTL::Size grid_dims,
|
MTL::Size grid_dims,
|
||||||
MTL::Size group_dims) {
|
MTL::Size group_dims) {
|
||||||
maybeInsertBarrier();
|
maybeInsertBarrier();
|
||||||
|
stream_.buffer_ops++;
|
||||||
enc_->dispatchThreads(grid_dims, group_dims);
|
enc_->dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -209,6 +214,31 @@ Device::Device() {
|
|||||||
device_ = load_device();
|
device_ = load_device();
|
||||||
library_map_ = {{"mlx", load_library(device_)}};
|
library_map_ = {{"mlx", load_library(device_)}};
|
||||||
arch_ = std::string(device_->architecture()->name()->utf8String());
|
arch_ = std::string(device_->architecture()->name()->utf8String());
|
||||||
|
auto arch = arch_.back();
|
||||||
|
switch (arch) {
|
||||||
|
case 'p': // phone
|
||||||
|
max_ops_per_buffer_ = 20;
|
||||||
|
max_mb_per_buffer_ = 40;
|
||||||
|
break;
|
||||||
|
case 'g': // base, pro
|
||||||
|
max_ops_per_buffer_ = 40;
|
||||||
|
max_mb_per_buffer_ = 40;
|
||||||
|
break;
|
||||||
|
case 's': // max
|
||||||
|
max_ops_per_buffer_ = 50;
|
||||||
|
max_mb_per_buffer_ = 50;
|
||||||
|
break;
|
||||||
|
case 'd': // ultra
|
||||||
|
max_ops_per_buffer_ = 50;
|
||||||
|
max_mb_per_buffer_ = 50;
|
||||||
|
break;
|
||||||
|
default: // default to medium
|
||||||
|
max_ops_per_buffer_ = 40;
|
||||||
|
max_mb_per_buffer_ = 40;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
max_ops_per_buffer_ = env::max_ops_per_buffer(max_ops_per_buffer_);
|
||||||
|
max_mb_per_buffer_ = env::max_mb_per_buffer(max_mb_per_buffer_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Device::~Device() {
|
Device::~Device() {
|
||||||
@ -239,12 +269,13 @@ void Device::new_queue(int index) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int Device::get_command_buffer_ops(int index) {
|
bool Device::command_buffer_needs_commit(int index) {
|
||||||
return get_stream_(index).buffer_ops;
|
auto& stream = get_stream_(index);
|
||||||
|
if (stream.buffer_ops > max_ops_per_buffer_ ||
|
||||||
|
(stream.buffer_sizes >> 20) > max_mb_per_buffer_) {
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
return false;
|
||||||
void Device::increment_command_buffer_ops(int index) {
|
|
||||||
get_stream_(index).buffer_ops++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
||||||
@ -267,6 +298,7 @@ void Device::commit_command_buffer(int index) {
|
|||||||
stream.buffer->release();
|
stream.buffer->release();
|
||||||
stream.buffer = nullptr;
|
stream.buffer = nullptr;
|
||||||
stream.buffer_ops = 0;
|
stream.buffer_ops = 0;
|
||||||
|
stream.buffer_sizes = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Device::add_temporary(array arr, int index) {
|
void Device::add_temporary(array arr, int index) {
|
||||||
@ -351,7 +383,7 @@ void Device::end_encoding(int index) {
|
|||||||
CommandEncoder& Device::get_command_encoder(int index) {
|
CommandEncoder& Device::get_command_encoder(int index) {
|
||||||
auto& stream = get_stream_(index);
|
auto& stream = get_stream_(index);
|
||||||
if (stream.encoder == nullptr) {
|
if (stream.encoder == nullptr) {
|
||||||
stream.encoder = std::make_unique<CommandEncoder>(stream.buffer);
|
stream.encoder = std::make_unique<CommandEncoder>(stream);
|
||||||
stream.fence = std::make_shared<Fence>(device_->newFence());
|
stream.fence = std::make_shared<Fence>(device_->newFence());
|
||||||
}
|
}
|
||||||
return *stream.encoder;
|
return *stream.encoder;
|
||||||
|
@ -38,8 +38,10 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
|||||||
using MTLFCList =
|
using MTLFCList =
|
||||||
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
|
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
|
||||||
|
|
||||||
|
struct DeviceStream;
|
||||||
|
|
||||||
struct CommandEncoder {
|
struct CommandEncoder {
|
||||||
CommandEncoder(MTL::CommandBuffer* cbuf);
|
explicit CommandEncoder(DeviceStream& stream);
|
||||||
CommandEncoder(const CommandEncoder&) = delete;
|
CommandEncoder(const CommandEncoder&) = delete;
|
||||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||||
|
|
||||||
@ -115,6 +117,7 @@ struct CommandEncoder {
|
|||||||
void barrier();
|
void barrier();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
DeviceStream& stream_;
|
||||||
MTL::ComputeCommandEncoder* enc_;
|
MTL::ComputeCommandEncoder* enc_;
|
||||||
bool needs_barrier_{false};
|
bool needs_barrier_{false};
|
||||||
bool concurrent_{false};
|
bool concurrent_{false};
|
||||||
@ -147,10 +150,10 @@ struct DeviceStream {
|
|||||||
// Used to allow thread-safe access to the outputs map
|
// Used to allow thread-safe access to the outputs map
|
||||||
std::mutex fence_mtx;
|
std::mutex fence_mtx;
|
||||||
|
|
||||||
// The buffer and buffer op count are updated
|
// Data updated between command buffers
|
||||||
// between command buffers
|
|
||||||
MTL::CommandBuffer* buffer{nullptr};
|
MTL::CommandBuffer* buffer{nullptr};
|
||||||
int buffer_ops{0};
|
int buffer_ops{0};
|
||||||
|
size_t buffer_sizes{0};
|
||||||
|
|
||||||
// The command encoder, fence, and temporaries are updated between command
|
// The command encoder, fence, and temporaries are updated between command
|
||||||
// encoders
|
// encoders
|
||||||
@ -176,8 +179,7 @@ class Device {
|
|||||||
|
|
||||||
void new_queue(int index);
|
void new_queue(int index);
|
||||||
MTL::CommandBuffer* get_command_buffer(int index);
|
MTL::CommandBuffer* get_command_buffer(int index);
|
||||||
int get_command_buffer_ops(int index);
|
bool command_buffer_needs_commit(int index);
|
||||||
void increment_command_buffer_ops(int index);
|
|
||||||
void commit_command_buffer(int index);
|
void commit_command_buffer(int index);
|
||||||
CommandEncoder& get_command_encoder(int index);
|
CommandEncoder& get_command_encoder(int index);
|
||||||
void end_encoding(int index);
|
void end_encoding(int index);
|
||||||
@ -267,6 +269,8 @@ class Device {
|
|||||||
std::unordered_map<std::string, MTL::Library*> library_map_;
|
std::unordered_map<std::string, MTL::Library*> library_map_;
|
||||||
const MTL::ResidencySet* residency_set_{nullptr};
|
const MTL::ResidencySet* residency_set_{nullptr};
|
||||||
std::string arch_;
|
std::string arch_;
|
||||||
|
int max_ops_per_buffer_;
|
||||||
|
int max_mb_per_buffer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
Device& device(mlx::core::Device);
|
Device& device(mlx::core::Device);
|
||||||
|
@ -109,7 +109,7 @@ std::tuple<bool, int64_t, array> check_transpose(
|
|||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
#define GEMM_TPARAM_MACRO(devc) \
|
#define GEMM_TPARAM_MACRO(devc) \
|
||||||
if (devc == 'g') { /* Small device */ \
|
if (devc == 'g' || devc == 'p') { /* Small device */ \
|
||||||
if (!transpose_a && transpose_b) { /* nt */ \
|
if (!transpose_a && transpose_b) { /* nt */ \
|
||||||
bm = 64; \
|
bm = 64; \
|
||||||
bn = 32; \
|
bn = 32; \
|
||||||
|
@ -29,7 +29,6 @@ std::function<void()> make_task(array arr, bool signal) {
|
|||||||
auto s = arr.primitive().stream();
|
auto s = arr.primitive().stream();
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
auto command_buffer = d.get_command_buffer(s.index);
|
auto command_buffer = d.get_command_buffer(s.index);
|
||||||
d.increment_command_buffer_ops(s.index);
|
|
||||||
|
|
||||||
for (auto& input : arr.inputs()) {
|
for (auto& input : arr.inputs()) {
|
||||||
if (input.event().valid() &&
|
if (input.event().valid() &&
|
||||||
@ -68,8 +67,7 @@ std::function<void()> make_task(array arr, bool signal) {
|
|||||||
out.set_status(array::Status::evaluated);
|
out.set_status(array::Status::evaluated);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (signal ||
|
if (signal || d.command_buffer_needs_commit(s.index)) {
|
||||||
d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) {
|
|
||||||
if (signal) {
|
if (signal) {
|
||||||
encode_signal(arr.event());
|
encode_signal(arr.event());
|
||||||
}
|
}
|
||||||
|
11
mlx/utils.h
11
mlx/utils.h
@ -122,11 +122,18 @@ inline int bfs_max_width() {
|
|||||||
return bfs_max_width_;
|
return bfs_max_width_;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int max_ops_per_buffer() {
|
inline int max_ops_per_buffer(int default_value) {
|
||||||
static int max_ops_per_buffer_ = get_var("MLX_MAX_OPS_PER_BUFFER", 10);
|
static int max_ops_per_buffer_ =
|
||||||
|
get_var("MLX_MAX_OPS_PER_BUFFER", default_value);
|
||||||
return max_ops_per_buffer_;
|
return max_ops_per_buffer_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline int max_mb_per_buffer(int default_value) {
|
||||||
|
static int max_mb_per_buffer_ =
|
||||||
|
get_var("MLX_MAX_MB_PER_BUFFER", default_value);
|
||||||
|
return max_mb_per_buffer_;
|
||||||
|
}
|
||||||
|
|
||||||
inline bool metal_fast_synch() {
|
inline bool metal_fast_synch() {
|
||||||
static bool metal_fast_synch = get_var("MLX_METAL_FAST_SYNCH", 0);
|
static bool metal_fast_synch = get_var("MLX_METAL_FAST_SYNCH", 0);
|
||||||
return metal_fast_synch;
|
return metal_fast_synch;
|
||||||
|
Loading…
Reference in New Issue
Block a user