mirror of https://github.com/ml-explore/mlx.git
Split encoders in non-concurrent context with a max ops per encoder (#1085)
* split encoders
* fix race
This commit is contained in:
parent b21242faf1
commit 06375e6605
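Note: the diff below touches dozens of call sites, but the mechanical change is the same everywhere: kernel dispatches now go through the `CommandEncoder` wrapper, which counts them and, outside explicitly concurrent regions, ends the current Metal compute encoder and opens a fresh one on the same command buffer once `MAX_DISPATCHES_PER_ENCODER` is exceeded. A minimal standalone sketch of that counting/split pattern (stand-in types rather than the real metal-cpp API; all names here are illustrative):

#include <cstdio>

// Stand-ins for MTL::CommandBuffer / MTL::ComputeCommandEncoder so the
// sketch compiles without metal-cpp; only the control flow is modeled.
struct FakeEncoder {
  void end() { std::printf("encoder ended\n"); }
};
struct FakeCommandBuffer {
  FakeEncoder* new_encoder() { return new FakeEncoder(); }
};

constexpr int kMaxDispatchesPerEncoder = 2; // mirrors MAX_DISPATCHES_PER_ENCODER

class Encoder {
 public:
  explicit Encoder(FakeCommandBuffer* cbuf)
      : cbuf_(cbuf), enc_(cbuf->new_encoder()) {}
  ~Encoder() {
    enc_->end();
    delete enc_;
  }

  void dispatch() {
    num_dispatches_++;
    std::printf("dispatch #%d\n", num_dispatches_);
    maybe_split();
  }

 private:
  void maybe_split() {
    // Split only outside concurrent regions, once the threshold is crossed.
    if (num_dispatches_ > kMaxDispatchesPerEncoder && !concurrent_) {
      enc_->end();
      delete enc_;
      num_dispatches_ = 0;
      enc_ = cbuf_->new_encoder(); // fresh encoder on the same command buffer
    }
  }

  FakeCommandBuffer* cbuf_;
  FakeEncoder* enc_;
  int num_dispatches_{0};
  bool concurrent_{false};
};

int main() {
  FakeCommandBuffer cbuf;
  Encoder enc(&cbuf);
  for (int i = 0; i < 5; i++) {
    enc.dispatch(); // splits once the count exceeds 2
  }
}

The `!concurrent` guard matters, presumably because a concurrent region tracks output resources on the current encoder in order to insert barriers, so the encoder must not be swapped out from under it.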
@@ -336,7 +336,7 @@ void Compiled::eval_gpu(
     MTL::Size grid_dims(nthreads, 1, 1);
     MTL::Size group_dims(
         std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
     size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@@ -347,7 +347,7 @@ void Compiled::eval_gpu(
     }
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
@@ -59,7 +59,7 @@ void explicit_gemm_conv_ND_gpu(
   MTL::Size grid_dims = MTL::Size(
       conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);

-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);

   // Reshape weight
   std::vector<int> wt_reshape{implicit_K, implicit_N};
@@ -137,7 +137,7 @@ void explicit_gemm_conv_group_ND_gpu(
   MTL::Size grid_dims = MTL::Size(
       conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);

-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);

   // Transpose kernel weights so that we can slice them by contiguous chunks
   // of channel groups.
@@ -247,7 +247,7 @@ void slow_conv_2D_gpu(
   compute_encoder.set_output_array(out, 2);

   compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }

 void implicit_gemm_conv_2D_gpu(
@@ -352,7 +352,7 @@ void implicit_gemm_conv_2D_gpu(
   compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);

   // Launch kernel
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }

 void implicit_gemm_conv_2D_general_gpu(
@@ -512,7 +512,7 @@ void implicit_gemm_conv_2D_general_gpu(
       base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);

   // Launch kernel
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }

 void winograd_conv_2D_gpu(
@@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
     MTL::Size group_dims = MTL::Size(32, bo, 1);
     MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);

-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }

   // Do input transform
@@ -641,7 +641,7 @@ void winograd_conv_2D_gpu(
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);

-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }

   // Do batched gemm
@@ -689,7 +689,7 @@ void winograd_conv_2D_gpu(
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);

-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }
 }
@@ -126,7 +126,7 @@ void copy_gpu_inplace(

     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     size_t nthreads = out.data_size();
     MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
@@ -135,7 +135,7 @@ void copy_gpu_inplace(
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
@@ -25,6 +25,7 @@ namespace {

 // TODO nicer way to set this or possibly expose as an environment variable
 constexpr int MAX_BUFFERS_PER_QUEUE = 12;
+constexpr int MAX_DISPATCHES_PER_ENCODER = 2;

 constexpr const char* default_mtllib_path = METAL_PATH;

@@ -37,7 +38,6 @@ auto load_device() {
   }
   return device;
 }
-
 std::pair<MTL::Library*, NS::Error*> load_library_from_path(
     MTL::Device* device,
     const char* path) {
@@ -116,6 +116,33 @@ MTL::Library* load_library(

 } // namespace

+void CommandEncoder::dispatchThreadgroups(
+    MTL::Size grid_dims,
+    MTL::Size group_dims) {
+  num_dispatches++;
+  enc->dispatchThreadgroups(grid_dims, group_dims);
+  maybe_split();
+}
+
+void CommandEncoder::dispatchThreads(
+    MTL::Size grid_dims,
+    MTL::Size group_dims) {
+  num_dispatches++;
+  enc->dispatchThreads(grid_dims, group_dims);
+  maybe_split();
+}
+
+void CommandEncoder::maybe_split() {
+  if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
+    enc->endEncoding();
+    enc->release();
+    num_dispatches = 0;
+    outputs.clear();
+    enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
+    enc->retain();
+  }
+}
+
 Device::Device() {
   auto pool = new_scoped_memory_pool();
   device_ = load_device();
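The counted dispatch methods above are also why every call site in this commit switches from `compute_encoder->dispatchThreads(...)` to `compute_encoder.dispatchThreads(...)`: arrow calls forward to the raw `MTL::ComputeCommandEncoder` (as the unchanged `compute_encoder->setBytes(...)` lines suggest), so a dispatch issued through `->` would bypass the counter. A standalone sketch of that two-surface interface (stand-in types, illustrative names):

#include <cstdio>

// Stand-in for MTL::ComputeCommandEncoder.
struct RawEncoder {
  void setBytes(const void*, int, int) { std::printf("setBytes\n"); }
  void dispatchThreads() { std::printf("raw dispatch\n"); }
};

class Encoder {
 public:
  // Raw, uncounted calls go through operator-> ...
  RawEncoder* operator->() { return &enc_; }

  // ...but dispatches use a member function so they can be counted
  // (and can trigger a split once a threshold is crossed).
  void dispatchThreads() {
    num_dispatches_++;
    enc_.dispatchThreads();
  }

  int num_dispatches() const { return num_dispatches_; }

 private:
  RawEncoder enc_;
  int num_dispatches_{0};
};

int main() {
  Encoder enc;
  int x = 42;
  enc->setBytes(&x, sizeof(int), 0); // forwarded, not counted
  enc.dispatchThreads();             // counted: the wrapper sees it
  std::printf("dispatches: %d\n", enc.num_dispatches());
}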
@@ -130,9 +157,6 @@ Device::~Device() {
   for (auto& b : buffer_map_) {
     b.second.second->release();
   }
-  for (auto& e : encoder_map_) {
-    (*e.second)->release();
-  }
   for (auto& k : kernel_map_) {
     k.second->release();
   }
@@ -169,27 +193,26 @@ void Device::increment_command_buffer_ops(int index) {

 MTL::CommandBuffer* Device::get_command_buffer(int index) {
   auto bit = buffer_map_.find(index);
-  return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
-}
-
-MTL::CommandBuffer* Device::new_command_buffer(int index) {
-  auto qit = queue_map_.find(index);
-  if (qit == queue_map_.end()) {
-    throw std::runtime_error(
-        "[metal::Device] Attempting to get command buffer for invalid queue.");
-  }
+  if (bit == buffer_map_.end()) {
+    auto qit = queue_map_.find(index);
+    if (qit == queue_map_.end()) {
+      throw std::runtime_error(
+          "[metal::Device] Attempting to get command buffer for invalid queue.");
+    }

-  auto cb = qit->second->commandBufferWithUnretainedReferences();
+    auto cb = qit->second->commandBufferWithUnretainedReferences();

-  if (!cb) {
-    throw std::runtime_error(
-        "[metal::Device] Unable to create new command buffer");
-  }
+    if (!cb) {
+      throw std::runtime_error(
+          "[metal::Device] Unable to create new command buffer");
+    }

-  // Increment ref count so the buffer is not garbage collected
-  cb->retain();
+    // Increment ref count so the buffer is not garbage collected
+    cb->retain();

-  return buffer_map_.insert({index, {0, cb}}).first->second.second;
+    bit = buffer_map_.insert({index, {0, cb}}).first;
+  }
+  return bit->second.second;
 }

 void Device::commit_command_buffer(int index) {
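With `new_command_buffer` gone, `get_command_buffer` becomes get-or-create: it returns the cached buffer for the queue index if one exists, and otherwise creates, retains, and caches one, so callers no longer have to handle a `nullptr` result. A standalone sketch of the find-or-insert shape of that function (plain standard-library types standing in for the Metal objects):

#include <memory>
#include <unordered_map>

struct Buffer {}; // stand-in for MTL::CommandBuffer

// Maps a queue index to {op count, buffer}, like buffer_map_.
std::unordered_map<int, std::pair<int, std::shared_ptr<Buffer>>> buffer_map;

// Get-or-create: the iterator from find() is reused for insert(),
// so there is a single exit and the result is never null.
Buffer* get_command_buffer(int index) {
  auto bit = buffer_map.find(index);
  if (bit == buffer_map.end()) {
    auto cb = std::make_shared<Buffer>(); // create lazily on first use
    bit = buffer_map.insert({index, {0, cb}}).first;
  }
  return bit->second.second.get();
}

int main() {
  Buffer* a = get_command_buffer(0); // created
  Buffer* b = get_command_buffer(0); // cached: same pointer
  return a == b ? 0 : 1;
}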
@@ -200,25 +223,15 @@ void Device::commit_command_buffer(int index) {
 }

 void Device::end_encoding(int index) {
-  auto eit = encoder_map_.find(index);
-  if (eit != encoder_map_.end()) {
-    (*eit->second)->endEncoding();
-    (*eit->second)->release();
-    encoder_map_.erase(eit);
-  }
+  encoder_map_.erase(index);
 }

 CommandEncoder& Device::get_command_encoder(int index) {
   auto eit = encoder_map_.find(index);
   if (eit == encoder_map_.end()) {
     auto cb = get_command_buffer(index);
-    auto compute_encoder =
-        cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
-    // Increment ref count so the buffer is not garbage collected
-    compute_encoder->retain();
-    eit = encoder_map_
-              .emplace(index, std::make_unique<CommandEncoder>(compute_encoder))
-              .first;
+    eit =
+        encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
   }
   return *(eit->second);
 }
@@ -37,8 +37,10 @@ using MTLFCList =
     std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;

 struct CommandEncoder {
-  CommandEncoder(MTL::ComputeCommandEncoder* enc)
-      : enc(enc), concurrent(false) {};
+  CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
+    enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
+    enc->retain();
+  };
   CommandEncoder(const CommandEncoder&) = delete;
   CommandEncoder& operator=(const CommandEncoder&) = delete;

@@ -89,13 +91,25 @@ struct CommandEncoder {
     }
   }

+  void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
+  void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
+
   ConcurrentContext start_concurrent() {
     return ConcurrentContext(*this);
   }

+  ~CommandEncoder() {
+    enc->endEncoding();
+    enc->release();
+  }
+
+ private:
+  void maybe_split();
+
+  int num_dispatches{0};
+  MTL::CommandBuffer* cbuf;
   MTL::ComputeCommandEncoder* enc;
-  bool concurrent;
+  bool concurrent{false};
   std::unordered_set<MTL::Resource*> outputs;
   std::unordered_set<MTL::Resource*> concurrent_outputs;
 };
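The struct now manages its own lifetime: the constructor takes the command buffer and opens the compute encoder itself, and the new destructor ends and releases whatever encoder is current. That is what lets `Device::end_encoding` collapse to `encoder_map_.erase(index)` earlier in this diff, since destroying the map entry runs the destructor. A hedged RAII sketch of the same ownership shape (stand-in types, illustrative names):

#include <cstdio>
#include <memory>
#include <unordered_map>

// Stand-ins for the Metal objects; only the ownership pattern is modeled.
struct RawEncoder {
  void endEncoding() { std::printf("endEncoding\n"); }
};
struct CommandBuffer {
  RawEncoder* computeCommandEncoder() { return new RawEncoder(); }
};

struct Encoder {
  explicit Encoder(CommandBuffer* cbuf)
      : cbuf(cbuf), enc(cbuf->computeCommandEncoder()) {}
  ~Encoder() {
    enc->endEncoding(); // RAII: closing happens in exactly one place
    delete enc;
  }
  CommandBuffer* cbuf;
  RawEncoder* enc;
};

int main() {
  CommandBuffer cb;
  std::unordered_map<int, std::unique_ptr<Encoder>> encoder_map;
  encoder_map.emplace(0, std::make_unique<Encoder>(&cb));
  encoder_map.erase(0); // ~Encoder runs: endEncoding is called
}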
@@ -112,7 +126,6 @@ class Device {
   };

   void new_queue(int index);
-  MTL::CommandBuffer* new_command_buffer(int index);
   MTL::CommandBuffer* get_command_buffer(int index);
   int get_command_buffer_ops(int index);
   void increment_command_buffer_ops(int index);
@@ -97,7 +97,7 @@ void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {

     auto group_dims = MTL::Size(1, m, 1);
     auto grid_dims = MTL::Size(batch, m, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -107,7 +107,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   }

   // Launch grid
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }

 void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -216,7 +216,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Launch grid
     MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
     MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);

   } else {
     // Collect all idx shapes and strides into one place
@@ -286,7 +286,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Launch grid
     MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
     MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
@@ -356,7 +356,7 @@ void steel_matmul_conv_groups(
   compute_encoder->setBytes(
       batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);

-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

   // Clear copies
   d.get_command_buffer(s.index)->addCompletedHandler(
@@ -468,7 +468,7 @@ void steel_matmul(
     compute_encoder.set_output_array(C_split, 2);

     compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

     // Do accum kernel
     {
@@ -493,7 +493,7 @@ void steel_matmul(
       MTL::Size grid_dims = MTL::Size(N, M, 1);
       MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);

-      compute_encoder->dispatchThreads(grid_dims, group_dims);
+      compute_encoder.dispatchThreads(grid_dims, group_dims);
     }

     d.get_command_buffer(s.index)->addCompletedHandler(
@@ -581,7 +581,7 @@ void steel_matmul(
   compute_encoder->setBytes(
       batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);

-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

   // Clear copies
   d.get_command_buffer(s.index)->addCompletedHandler(
@@ -748,7 +748,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder->setBytes(
         batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);

-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

     d.get_command_buffer(s.index)->addCompletedHandler(
         [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -968,7 +968,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     int bias_stride = c.strides()[c.ndim() - 1];
     compute_encoder->setBytes(&bias_stride, sizeof(int), 14);

-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

     d.get_command_buffer(s.index)->addCompletedHandler(
         [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -1038,7 +1038,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.set_output_array(C_split, 2);

     compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

     // Do accum kernel
     {
@@ -1063,7 +1063,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       MTL::Size grid_dims = MTL::Size(N, M, 1);
       MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);

-      compute_encoder->dispatchThreads(grid_dims, group_dims);
+      compute_encoder.dispatchThreads(grid_dims, group_dims);
     }

     d.get_command_buffer(s.index)->addCompletedHandler(
@@ -1160,7 +1160,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setBytes(
       batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);

-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -1346,7 +1346,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_input_array(out_mask, 10);
   set_vector_bytes(compute_encoder, mask_strides, 13);

-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

   // Clear copies
   d.get_command_buffer(s.index)->addCompletedHandler(
@@ -1566,7 +1566,7 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
   compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));

-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -1656,7 +1656,7 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   set_vector_bytes(compute_encoder, batch_strides_B, 15);
   set_vector_bytes(compute_encoder, operand_batch_ndim, 16);

-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

   // Clear copies
   d.get_command_buffer(s.index)->addCompletedHandler(
@@ -27,24 +27,6 @@ int max_ops_per_buffer() {

 #define MAX_OPS_PER_BUFFER max_ops_per_buffer()

-MTL::CommandBuffer* increment_command_buffer(Stream s) {
-  auto& d = metal::device(s.device);
-  auto command_buffer = d.get_command_buffer(s.index);
-  if (command_buffer == nullptr ||
-      d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
-    if (command_buffer != nullptr) {
-      d.end_encoding(s.index);
-      scheduler::notify_new_task(s);
-      command_buffer->addCompletedHandler(
-          [s](MTL::CommandBuffer*) { scheduler::notify_task_completion(s); });
-      d.commit_command_buffer(s.index);
-    }
-    command_buffer = d.new_command_buffer(s.index);
-  }
-  d.increment_command_buffer_ops(s.index);
-  return command_buffer;
-}
-
 inline void check_error(MTL::CommandBuffer* cbuf) {
   if (cbuf->status() == MTL::CommandBufferStatusError) {
     std::ostringstream msg;
@@ -58,7 +40,10 @@ std::function<void()> make_task(array arr, bool signal) {
   auto task = [arr = std::move(arr), signal]() mutable {
     auto pool = new_scoped_memory_pool();
     auto s = arr.primitive().stream();
-    auto command_buffer = increment_command_buffer(s);
+    auto& d = metal::device(s.device);
+    auto command_buffer = d.get_command_buffer(s.index);
+    d.increment_command_buffer_ops(s.index);

     for (auto& input : arr.inputs()) {
       if (input.event().valid() &&
           input.event().stream() != arr.primitive().stream()) {
@@ -91,11 +76,13 @@ std::function<void()> make_task(array arr, bool signal) {
       arr.detach();
     }

-    if (signal) {
-      metal::device(s.device).end_encoding(s.index);
-      command_buffer->encodeSignalEvent(
-          static_cast<MTL::Event*>(arr.event().raw_event().get()),
-          arr.event().value());
+    if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
+      d.end_encoding(s.index);
+      if (signal) {
+        command_buffer->encodeSignalEvent(
+            static_cast<MTL::Event*>(arr.event().raw_event().get()),
+            arr.event().value());
+      }
       scheduler::notify_new_task(s);
       command_buffer->addCompletedHandler(
           [s, buffers = std::move(buffers), event = arr.event()](
@@ -103,7 +90,8 @@ std::function<void()> make_task(array arr, bool signal) {
             scheduler::notify_task_completion(s);
             check_error(cbuf);
           });
-      metal::device(s.device).commit_command_buffer(s.index);
+      d.commit_command_buffer(s.index);
+      d.get_command_buffer(s.index);
     } else {
       command_buffer->addCompletedHandler(
           [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
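The old `increment_command_buffer` rotated command buffers at the start of a task; after this change the decision moves to the end of the task: commit when a signal must be encoded or when the op count reaches `MAX_OPS_PER_BUFFER`, then immediately call `d.get_command_buffer(s.index)` so the now lazily-creating getter sets up a fresh buffer for the next task. A standalone sketch of that end-of-task flow (stand-in device type; names are illustrative):

#include <cstdio>

constexpr int kMaxOpsPerBuffer = 10; // stands in for MAX_OPS_PER_BUFFER

// Stand-in for metal::Device with a single queue.
struct Device {
  int ops = 0;
  int buffer_id = 0;
  void commit() { std::printf("commit buffer %d (%d ops)\n", buffer_id, ops); }
  void fresh_buffer() { buffer_id++; ops = 0; } // like get_command_buffer()
};

// Models the tail of make_task: commit when signaling or when full,
// then eagerly make the next buffer available.
void finish_task(Device& d, bool signal) {
  d.ops++;
  if (signal || d.ops >= kMaxOpsPerBuffer) {
    d.commit();
    d.fresh_buffer();
  }
}

int main() {
  Device d;
  for (int i = 0; i < 25; i++) {
    finish_task(d, /*signal=*/false); // commits every 10 ops
  }
  finish_task(d, /*signal=*/true); // a signal forces an early commit
}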
@@ -120,14 +108,12 @@ std::function<void()> make_synchronize_task(
   return [s, p = std::move(p)]() {
     auto& d = metal::device(s.device);
     auto cb = d.get_command_buffer(s.index);
-    if (cb == nullptr) {
-      cb = d.new_command_buffer(s.index);
-    } else {
-      d.end_encoding(s.index);
-    }
     cb->retain();
+    d.end_encoding(s.index);
     d.commit_command_buffer(s.index);
     cb->waitUntilCompleted();
     check_error(cb);
     cb->release();
     p->set_value();
   };
 }
@@ -89,7 +89,7 @@ void RMSNorm::eval_gpu(
     compute_encoder->setThreadgroupMemoryLength(
         16 * 8, 0); // minimum of 16 bytes
     compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -190,7 +190,7 @@ void RMSNormVJP::eval_gpu(
     compute_encoder->setBytes(&eps_, sizeof(float), 5);
     compute_encoder->setBytes(&axis_size, sizeof(int), 6);
     compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }

   ReductionPlan plan(
@@ -282,7 +282,7 @@ void LayerNorm::eval_gpu(
     compute_encoder->setBytes(&axis_size, sizeof(int), 5);
     compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
     compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -401,7 +401,7 @@ void LayerNormVJP::eval_gpu(
     compute_encoder->setBytes(&eps_, sizeof(float), 5);
     compute_encoder->setBytes(&axis_size, sizeof(int), 6);
     compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }

   if (gw.ndim() == 1 && gw.size() == axis_size) {
@@ -107,7 +107,7 @@ void binary_op(
     }
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     // Launch a 1D grid of threads
     size_t nthreads = out.data_size();
@@ -117,7 +117,7 @@ void binary_op(
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
@@ -201,7 +201,7 @@ void binary_op(
     }
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     // Launch a 1D grid of threads
     size_t nthreads =
@@ -212,7 +212,7 @@ void binary_op(
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
@@ -288,7 +288,7 @@ void ternary_op(
     }
     MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     // Launch a 1D grid of threads
     size_t nthreads = out.data_size();
@@ -298,7 +298,7 @@ void ternary_op(
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
@@ -351,7 +351,7 @@ void unary_op(
     int ndim = in.ndim();
     compute_encoder->setBytes(&ndim, sizeof(int), 4);
   }
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }

 } // namespace
@@ -428,7 +428,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   }

   compute_encoder.set_output_array(out, 2);
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }

 void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -523,7 +523,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
     compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
     compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
@@ -834,7 +834,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
         keys.strides().data(), keys.ndim() * sizeof(size_t), 6);
   }

-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }

 void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -65,7 +65,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       compute_encoder->setBytes(&D, sizeof(int), 5);
       compute_encoder->setBytes(&O, sizeof(int), 6);

-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }

     // Route to the qmv kernel
@@ -92,7 +92,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
      compute_encoder->setBytes(&D, sizeof(int), 5);
      compute_encoder->setBytes(&O, sizeof(int), 6);

-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }

     // Route to the qmm_t kernel
@@ -123,7 +123,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       compute_encoder->setBytes(&O, sizeof(int), 6);
       compute_encoder->setBytes(&D, sizeof(int), 7);

-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }
   } else {
     // Route to the qvm kernel
@@ -150,7 +150,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
      compute_encoder->setBytes(&D, sizeof(int), 5);
      compute_encoder->setBytes(&O, sizeof(int), 6);

-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }

     // Route to the qmm_n kernel
@@ -188,7 +188,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       compute_encoder->setBytes(&O, sizeof(int), 6);
       compute_encoder->setBytes(&D, sizeof(int), 7);

-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }
   }
@@ -74,7 +74,7 @@ void all_reduce_dispatch(
     compute_encoder.set_input_array(in, 0);
     compute_encoder.set_output_array(out, 1);
     compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);

   } else {
     // Allocate intermediate array to store partial reduction results
@@ -88,7 +88,7 @@ void all_reduce_dispatch(
     compute_encoder.set_input_array(in, 0);
     compute_encoder.set_output_array(intermediate, 1);
     compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);

     // Second pass to reduce intermediate reduction results written to DRAM
     compute_encoder.set_input_array(intermediate, 0);
@@ -108,7 +108,7 @@ void all_reduce_dispatch(
     nthreads = thread_group_size;
     group_dims = MTL::Size(thread_group_size, 1, 1);
     grid_dims = MTL::Size(nthreads, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);

     d.get_command_buffer(s.index)->addCompletedHandler(
         [intermediates](MTL::CommandBuffer*) mutable {
@@ -217,7 +217,7 @@ void row_reduce_general_dispatch(
     compute_encoder->setBytes(
         strides.data(), strides.size() * sizeof(size_t), 6);
     compute_encoder->setBytes(&ndim, sizeof(int), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);

   } else {
     // Allocate intermediate array to store partial reduction results
@@ -239,7 +239,7 @@ void row_reduce_general_dispatch(
     compute_encoder->setBytes(
         strides.data(), strides.size() * sizeof(size_t), 6);
     compute_encoder->setBytes(&ndim, sizeof(int), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);

     // Set up second dispatch
     reduction_size = non_row_reductions;
@@ -286,7 +286,7 @@ void row_reduce_general_dispatch(
     grid_dims = MTL::Size(n_threads, out.size(), 1);
     group_dims = MTL::Size(thread_group_size, 1, 1);

-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);

     d.get_command_buffer(s.index)->addCompletedHandler(
         [intermediates](MTL::CommandBuffer*) mutable {
@@ -366,7 +366,7 @@ void strided_reduce_general_dispatch(
     compute_encoder->setBytes(&non_col_ndim, sizeof(int), 11);

     // Dispatch threads
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);

     return;
   }
@@ -435,7 +435,7 @@ void strided_reduce_general_dispatch(
     compute_encoder->setThreadgroupMemoryLength(
         safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
         0);
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

   } else {
     // Allocate intermediate array to store reduction results from all thread
@@ -470,7 +470,7 @@ void strided_reduce_general_dispatch(
     compute_encoder->setThreadgroupMemoryLength(
         safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
         0);
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

     // Perform second pass of reductions
     // Reduce results of threadgroups along y, z from first pass, that
@@ -523,7 +523,7 @@ void strided_reduce_general_dispatch(
     grid_dims = MTL::Size(n_threads, out.size(), 1);
     group_dims = MTL::Size(thread_group_size, 1, 1);

-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);

     d.get_command_buffer(s.index)->addCompletedHandler(
         [intermediates](MTL::CommandBuffer*) mutable {
@@ -585,7 +585,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
     compute_encoder->setComputePipelineState(kernel);
     compute_encoder.set_output_array(out, 0);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }

   // Reduce
@@ -83,7 +83,7 @@ void RoPE::eval_gpu(
   int dim2 = in.size() / mat_size;
   auto group_dims = get_block_dims(dim0, dim1, dim2);
   auto grid_dims = MTL::Size(dim0, dim1, dim2);
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }

 } // namespace mlx::core::fast
@@ -99,7 +99,7 @@ void sdpa_metal(

   constexpr const uint tgroupMemorySize = 32768;
   compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

   {
     auto kernel_accum = d.get_kernel(kname_reduce.str());
@@ -114,7 +114,7 @@ void sdpa_metal(
     MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
     MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);

-    compute_encoder->dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
+    compute_encoder.dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);

     d.get_command_buffer(s.index)->addCompletedHandler(
         [temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
@@ -77,7 +77,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
         static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
     MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     kname << "strided_scan_";
     if (reverse_) {
@@ -119,7 +119,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
     MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
     MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }

   if (copies.size() > 0) {
@@ -85,7 +85,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
         in.data_shared_ptr() == nullptr ? out : in, 0);
     compute_encoder.set_output_array(out, 1);
     compute_encoder->setBytes(&axis_size, sizeof(int), 2);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -78,7 +78,7 @@ void single_block_sort(
   MTL::Size group_dims = MTL::Size(bn, 1, 1);
   MTL::Size grid_dims = MTL::Size(1, n_rows, 1);

-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }

 template <bool ARGSORT>
@@ -155,7 +155,7 @@ void multi_block_sort(
     MTL::Size group_dims = MTL::Size(bn, 1, 1);
     MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);

-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }

   // Do merges
@@ -190,7 +190,7 @@ void multi_block_sort(
     MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
     MTL::Size grid_dims = MTL::Size(1, n_rows, 1);

-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }

   // Do merge
@@ -214,7 +214,7 @@ void multi_block_sort(
     MTL::Size group_dims = MTL::Size(bn, 1, 1);
     MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);

-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }
 }