Split encoders in non-concurrent context with a max ops per encoder (#1085)

* split encoders

* fix race
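In short: kernel launches now go through CommandEncoder::dispatchThreads / dispatchThreadgroups, which count dispatches and, outside of a concurrent context, end the current MTL::ComputeCommandEncoder and open a fresh one on the same command buffer once the count exceeds MAX_DISPATCHES_PER_ENCODER. A condensed view of the splitting logic, taken from the device.cpp changes below with comments added (the race remark is my reading of the second bullet, not stated in the diff):

  void CommandEncoder::maybe_split() {
    // Only split when we are not inside a ConcurrentContext.
    if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
      enc->endEncoding();
      enc->release();
      num_dispatches = 0;
      // Reset the outputs tracked for barrier insertion: ending the encoder
      // already serializes against the prior dispatches, and carrying the old
      // set over to the fresh encoder is what the race fix appears to avoid.
      outputs.clear();
      enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
      enc->retain();
    }
  }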
Awni Hannun 2024-05-09 16:21:02 -07:00 committed by GitHub
parent b21242faf1
commit 06375e6605
18 changed files with 150 additions and 138 deletions

View File

@ -336,7 +336,7 @@ void Compiled::eval_gpu(
MTL::Size grid_dims(nthreads, 1, 1);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@ -347,7 +347,7 @@ void Compiled::eval_gpu(
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
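The call-site change above repeats mechanically across the remaining kernel files: dispatches that previously reached the raw MTL::ComputeCommandEncoder through the wrapper's operator-> now call the CommandEncoder's own dispatch methods, so they are counted and can trigger a split. For example (both lines appear verbatim in the hunk above):

  compute_encoder->dispatchThreads(grid_dims, group_dims);  // old: bypasses the wrapper, dispatch not counted
  compute_encoder.dispatchThreads(grid_dims, group_dims);   // new: counted by CommandEncoder, may split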

View File

@ -59,7 +59,7 @@ void explicit_gemm_conv_ND_gpu(
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Reshape weight
std::vector<int> wt_reshape{implicit_K, implicit_N};
@ -137,7 +137,7 @@ void explicit_gemm_conv_group_ND_gpu(
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Transpose kernel weights so that we can slice them by contiguous chunks
// of channel groups.
@ -247,7 +247,7 @@ void slow_conv_2D_gpu(
compute_encoder.set_output_array(out, 2);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_gpu(
@ -352,7 +352,7 @@ void implicit_gemm_conv_2D_gpu(
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
// Launch kernel
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_general_gpu(
@ -512,7 +512,7 @@ void implicit_gemm_conv_2D_general_gpu(
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
// Launch kernel
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void winograd_conv_2D_gpu(
@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(32, bo, 1);
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do input transform
@ -641,7 +641,7 @@ void winograd_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do batched gemm
@ -689,7 +689,7 @@ void winograd_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}

View File

@ -126,7 +126,7 @@ void copy_gpu_inplace(
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
@ -135,7 +135,7 @@ void copy_gpu_inplace(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@ -25,6 +25,7 @@ namespace {
// TODO nicer way to set this or possibly expose as an environment variable
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
constexpr const char* default_mtllib_path = METAL_PATH;
@ -37,7 +38,6 @@ auto load_device() {
}
return device;
}
std::pair<MTL::Library*, NS::Error*> load_library_from_path(
MTL::Device* device,
const char* path) {
@ -116,6 +116,33 @@ MTL::Library* load_library(
} // namespace
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
num_dispatches++;
enc->dispatchThreadgroups(grid_dims, group_dims);
maybe_split();
}
void CommandEncoder::dispatchThreads(
MTL::Size grid_dims,
MTL::Size group_dims) {
num_dispatches++;
enc->dispatchThreads(grid_dims, group_dims);
maybe_split();
}
void CommandEncoder::maybe_split() {
if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
enc->endEncoding();
enc->release();
num_dispatches = 0;
outputs.clear();
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
}
}
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
@ -130,9 +157,6 @@ Device::~Device() {
for (auto& b : buffer_map_) {
b.second.second->release();
}
for (auto& e : encoder_map_) {
(*e.second)->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
@ -169,27 +193,26 @@ void Device::increment_command_buffer_ops(int index) {
MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto bit = buffer_map_.find(index);
return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
}
if (bit == buffer_map_.end()) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
}
MTL::CommandBuffer* Device::new_command_buffer(int index) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
bit = buffer_map_.insert({index, {0, cb}}).first;
}
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
return buffer_map_.insert({index, {0, cb}}).first->second.second;
return bit->second.second;
}
void Device::commit_command_buffer(int index) {
@ -200,25 +223,15 @@ void Device::commit_command_buffer(int index) {
}
void Device::end_encoding(int index) {
auto eit = encoder_map_.find(index);
if (eit != encoder_map_.end()) {
(*eit->second)->endEncoding();
(*eit->second)->release();
encoder_map_.erase(eit);
}
encoder_map_.erase(index);
}
CommandEncoder& Device::get_command_encoder(int index) {
auto eit = encoder_map_.find(index);
if (eit == encoder_map_.end()) {
auto cb = get_command_buffer(index);
auto compute_encoder =
cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
// Increment ref count so the buffer is not garbage collected
compute_encoder->retain();
eit = encoder_map_
.emplace(index, std::make_unique<CommandEncoder>(compute_encoder))
.first;
eit =
encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
}
return *(eit->second);
}
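Net effect of the device.cpp changes above: get_command_buffer now lazily creates (and retains) a command buffer for the queue on first use, new_command_buffer is gone, and get_command_encoder builds a CommandEncoder directly from that buffer, so end_encoding reduces to a map erase (the CommandEncoder destructor ends and releases the MTL encoder). A rough usage sketch under those assumptions; s, kernel, grid_dims and group_dims are placeholders, not part of this diff:

  auto& d = metal::device(s.device);
  auto& compute_encoder = d.get_command_encoder(s.index);  // creates buffer + encoder on demand
  compute_encoder->setComputePipelineState(kernel);        // operator-> still forwards to the MTL encoder
  compute_encoder.dispatchThreads(grid_dims, group_dims);  // counted, may split the encoder
  d.end_encoding(s.index);                                 // drops the CommandEncoder, ending encoding
  d.commit_command_buffer(s.index);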

View File

@ -37,8 +37,10 @@ using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
struct CommandEncoder {
CommandEncoder(MTL::ComputeCommandEncoder* enc)
: enc(enc), concurrent(false) {};
CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
};
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
@ -89,13 +91,25 @@ struct CommandEncoder {
}
}
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
}
~CommandEncoder() {
enc->endEncoding();
enc->release();
}
private:
void maybe_split();
int num_dispatches{0};
MTL::CommandBuffer* cbuf;
MTL::ComputeCommandEncoder* enc;
bool concurrent;
bool concurrent{false};
std::unordered_set<MTL::Resource*> outputs;
std::unordered_set<MTL::Resource*> concurrent_outputs;
};
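Lifetime-wise, the MTL encoder is now owned by the CommandEncoder: it is created from the command buffer in the constructor and ended/released in the destructor, so nothing outside the struct manages the raw encoder. Splitting also interacts with start_concurrent(): as I read it, the ConcurrentContext defined earlier in this header sets the concurrent flag, so maybe_split is a no-op inside such a block and those dispatches stay on one encoder. An illustrative block, with grid_a/group_a and grid_b/group_b as hypothetical launch sizes:

  {
    auto ctx = compute_encoder.start_concurrent();
    compute_encoder.dispatchThreads(grid_a, group_a);  // not split while concurrent
    compute_encoder.dispatchThreads(grid_b, group_b);
  }  // leaving the context restores non-concurrent counting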
@ -112,7 +126,6 @@ class Device {
};
void new_queue(int index);
MTL::CommandBuffer* new_command_buffer(int index);
MTL::CommandBuffer* get_command_buffer(int index);
int get_command_buffer_ops(int index);
void increment_command_buffer_ops(int index);

View File

@ -97,7 +97,7 @@ void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
auto group_dims = MTL::Size(1, m, 1);
auto grid_dims = MTL::Size(batch, m, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });

View File

@ -107,7 +107,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
}
// Launch grid
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -216,7 +216,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Collect all idx shapes and strides into one place
@ -286,7 +286,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@ -356,7 +356,7 @@ void steel_matmul_conv_groups(
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
@ -468,7 +468,7 @@ void steel_matmul(
compute_encoder.set_output_array(C_split, 2);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Do accum kernel
{
@ -493,7 +493,7 @@ void steel_matmul(
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
@ -581,7 +581,7 @@ void steel_matmul(
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
@ -748,7 +748,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(
batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@ -968,7 +968,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int bias_stride = c.strides()[c.ndim() - 1];
compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@ -1038,7 +1038,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(C_split, 2);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Do accum kernel
{
@ -1063,7 +1063,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
@ -1160,7 +1160,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@ -1346,7 +1346,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(out_mask, 10);
set_vector_bytes(compute_encoder, mask_strides, 13);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
@ -1566,7 +1566,7 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@ -1656,7 +1656,7 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
set_vector_bytes(compute_encoder, batch_strides_B, 15);
set_vector_bytes(compute_encoder, operand_batch_ndim, 16);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(

View File

@ -27,24 +27,6 @@ int max_ops_per_buffer() {
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
MTL::CommandBuffer* increment_command_buffer(Stream s) {
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
if (command_buffer == nullptr ||
d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
if (command_buffer != nullptr) {
d.end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s](MTL::CommandBuffer*) { scheduler::notify_task_completion(s); });
d.commit_command_buffer(s.index);
}
command_buffer = d.new_command_buffer(s.index);
}
d.increment_command_buffer_ops(s.index);
return command_buffer;
}
inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
@ -58,7 +40,10 @@ std::function<void()> make_task(array arr, bool signal) {
auto task = [arr = std::move(arr), signal]() mutable {
auto pool = new_scoped_memory_pool();
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
d.increment_command_buffer_ops(s.index);
for (auto& input : arr.inputs()) {
if (input.event().valid() &&
input.event().stream() != arr.primitive().stream()) {
@ -91,11 +76,13 @@ std::function<void()> make_task(array arr, bool signal) {
arr.detach();
}
if (signal) {
metal::device(s.device).end_encoding(s.index);
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(arr.event().raw_event().get()),
arr.event().value());
if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
d.end_encoding(s.index);
if (signal) {
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(arr.event().raw_event().get()),
arr.event().value());
}
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers), event = arr.event()](
@ -103,7 +90,8 @@ std::function<void()> make_task(array arr, bool signal) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
metal::device(s.device).commit_command_buffer(s.index);
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
@ -120,14 +108,12 @@ std::function<void()> make_synchronize_task(
return [s, p = std::move(p)]() {
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
if (cb == nullptr) {
cb = d.new_command_buffer(s.index);
} else {
d.end_encoding(s.index);
}
cb->retain();
d.end_encoding(s.index);
d.commit_command_buffer(s.index);
cb->waitUntilCompleted();
check_error(cb);
cb->release();
p->set_value();
};
}
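To summarize the new flow in make_task: the task no longer calls increment_command_buffer; it fetches the current command buffer (created on demand), bumps the per-buffer op count, and only ends encoding and commits when a signal is requested or the count reaches MAX_OPS_PER_BUFFER, immediately fetching the next buffer so encoding can continue. A condensed sketch of that tail, pieced together from the hunks above (signal event and completed handlers elided):

  auto& d = metal::device(s.device);
  auto command_buffer = d.get_command_buffer(s.index);  // created lazily now
  d.increment_command_buffer_ops(s.index);
  // ... encode the op ...
  if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
    d.end_encoding(s.index);        // finish the open compute encoder, if any
    scheduler::notify_new_task(s);
    d.commit_command_buffer(s.index);
    d.get_command_buffer(s.index);  // eagerly start the next buffer
  }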

View File

@ -89,7 +89,7 @@ void RMSNorm::eval_gpu(
compute_encoder->setThreadgroupMemoryLength(
16 * 8, 0); // minimum of 16 bytes
compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@ -190,7 +190,7 @@ void RMSNormVJP::eval_gpu(
compute_encoder->setBytes(&eps_, sizeof(float), 5);
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
ReductionPlan plan(
@ -282,7 +282,7 @@ void LayerNorm::eval_gpu(
compute_encoder->setBytes(&axis_size, sizeof(int), 5);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@ -401,7 +401,7 @@ void LayerNormVJP::eval_gpu(
compute_encoder->setBytes(&eps_, sizeof(float), 5);
compute_encoder->setBytes(&axis_size, sizeof(int), 6);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
if (gw.ndim() == 1 && gw.size() == axis_size) {

View File

@ -107,7 +107,7 @@ void binary_op(
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = out.data_size();
@ -117,7 +117,7 @@ void binary_op(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@ -201,7 +201,7 @@ void binary_op(
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads =
@ -212,7 +212,7 @@ void binary_op(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@ -288,7 +288,7 @@ void ternary_op(
}
MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = out.data_size();
@ -298,7 +298,7 @@ void ternary_op(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@ -351,7 +351,7 @@ void unary_op(
int ndim = in.ndim();
compute_encoder->setBytes(&ndim, sizeof(int), 4);
}
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace
@ -428,7 +428,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
}
compute_encoder.set_output_array(out, 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -523,7 +523,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@ -834,7 +834,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
keys.strides().data(), keys.ndim() * sizeof(size_t), 6);
}
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@ -65,7 +65,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmv kernel
@ -92,7 +92,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmm_t kernel
@ -123,7 +123,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
} else {
// Route to the qvm kernel
@ -150,7 +150,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmm_n kernel
@ -188,7 +188,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}

View File

@ -74,7 +74,7 @@ void all_reduce_dispatch(
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Allocate intermediate array to store partial reduction results
@ -88,7 +88,7 @@ void all_reduce_dispatch(
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Second pass to reduce intermediate reduction results written to DRAM
compute_encoder.set_input_array(intermediate, 0);
@ -108,7 +108,7 @@ void all_reduce_dispatch(
nthreads = thread_group_size;
group_dims = MTL::Size(thread_group_size, 1, 1);
grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
@ -217,7 +217,7 @@ void row_reduce_general_dispatch(
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Allocate intermediate array to store partial reduction results
@ -239,7 +239,7 @@ void row_reduce_general_dispatch(
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Set up second dispatch
reduction_size = non_row_reductions;
@ -286,7 +286,7 @@ void row_reduce_general_dispatch(
grid_dims = MTL::Size(n_threads, out.size(), 1);
group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
@ -366,7 +366,7 @@ void strided_reduce_general_dispatch(
compute_encoder->setBytes(&non_col_ndim, sizeof(int), 11);
// Dispatch threads
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
return;
}
@ -435,7 +435,7 @@ void strided_reduce_general_dispatch(
compute_encoder->setThreadgroupMemoryLength(
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
} else {
// Allocate intermediate array to store reduction results from all thread
@ -470,7 +470,7 @@ void strided_reduce_general_dispatch(
compute_encoder->setThreadgroupMemoryLength(
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Perform second pass of reductions
// Reduce results of threadgroups along y, z from first pass, that
@ -523,7 +523,7 @@ void strided_reduce_general_dispatch(
grid_dims = MTL::Size(n_threads, out.size(), 1);
group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
@ -585,7 +585,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_output_array(out, 0);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
// Reduce

View File

@ -83,7 +83,7 @@ void RoPE::eval_gpu(
int dim2 = in.size() / mat_size;
auto group_dims = get_block_dims(dim0, dim1, dim2);
auto grid_dims = MTL::Size(dim0, dim1, dim2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core::fast

View File

@ -99,7 +99,7 @@ void sdpa_metal(
constexpr const uint tgroupMemorySize = 32768;
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
{
auto kernel_accum = d.get_kernel(kname_reduce.str());
@ -114,7 +114,7 @@ void sdpa_metal(
MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
compute_encoder->dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
compute_encoder.dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
d.get_command_buffer(s.index)->addCompletedHandler(
[temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });

View File

@ -77,7 +77,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
kname << "strided_scan_";
if (reverse_) {
@ -119,7 +119,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
if (copies.size() > 0) {

View File

@ -85,7 +85,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });

View File

@ -78,7 +78,7 @@ void single_block_sort(
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
template <bool ARGSORT>
@ -155,7 +155,7 @@ void multi_block_sort(
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do merges
@ -190,7 +190,7 @@ void multi_block_sort(
MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do merge
@ -214,7 +214,7 @@ void multi_block_sort(
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}