diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp
index e3da32c4b..8b11daa3a 100644
--- a/mlx/backend/metal/compiled.cpp
+++ b/mlx/backend/metal/compiled.cpp
@@ -336,7 +336,7 @@ void Compiled::eval_gpu(
     MTL::Size grid_dims(nthreads, 1, 1);
     MTL::Size group_dims(
         std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
     size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@@ -347,7 +347,7 @@ void Compiled::eval_gpu(
     }
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index c8fd95c1a..165d66050 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -59,7 +59,7 @@ void explicit_gemm_conv_ND_gpu(
   MTL::Size grid_dims = MTL::Size(
       conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
 
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 
   // Reshape weight
   std::vector<int> wt_reshape{implicit_K, implicit_N};
@@ -137,7 +137,7 @@ void explicit_gemm_conv_group_ND_gpu(
   MTL::Size grid_dims = MTL::Size(
       conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
 
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 
   // Transpose kernel weights so that we can slice them by contiguous chunks
   // of channel groups.
@@ -247,7 +247,7 @@ void slow_conv_2D_gpu(
   compute_encoder.set_output_array(out, 2);
   compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }
 
 void implicit_gemm_conv_2D_gpu(
@@ -352,7 +352,7 @@ void implicit_gemm_conv_2D_gpu(
   compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
 
   // Launch kernel
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }
 
 void implicit_gemm_conv_2D_general_gpu(
@@ -512,7 +512,7 @@ void implicit_gemm_conv_2D_general_gpu(
       base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
 
   // Launch kernel
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }
 
 void winograd_conv_2D_gpu(
@@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
     MTL::Size group_dims = MTL::Size(32, bo, 1);
     MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
 
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }
 
   // Do input transform
@@ -641,7 +641,7 @@ void winograd_conv_2D_gpu(
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
 
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }
 
   // Do batched gemm
@@ -689,7 +689,7 @@ void winograd_conv_2D_gpu(
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
 
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }
 }
diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp
index fd4e920f6..63ada6c0e 100644
--- a/mlx/backend/metal/copy.cpp
+++ b/mlx/backend/metal/copy.cpp
@@ -126,7 +126,7 @@ void copy_gpu_inplace(
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
 
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     size_t nthreads = out.data_size();
     MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
@@ -135,7 +135,7 @@ void copy_gpu_inplace(
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index 155fdf356..f6db67c68 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -25,6 +25,7 @@ namespace {
 
 // TODO nicer way to set this or possibly expose as an environment variable
 constexpr int MAX_BUFFERS_PER_QUEUE = 12;
+constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
 
 constexpr const char* default_mtllib_path = METAL_PATH;
 
@@ -37,7 +38,6 @@ auto load_device() {
   }
   return device;
 }
-
 std::pair<MTL::Library*, NS::Error*> load_library_from_path(
     MTL::Device* device,
     const char* path) {
@@ -116,6 +116,33 @@ MTL::Library* load_library(
 
 } // namespace
 
+void CommandEncoder::dispatchThreadgroups(
+    MTL::Size grid_dims,
+    MTL::Size group_dims) {
+  num_dispatches++;
+  enc->dispatchThreadgroups(grid_dims, group_dims);
+  maybe_split();
+}
+
+void CommandEncoder::dispatchThreads(
+    MTL::Size grid_dims,
+    MTL::Size group_dims) {
+  num_dispatches++;
+  enc->dispatchThreads(grid_dims, group_dims);
+  maybe_split();
+}
+
+void CommandEncoder::maybe_split() {
+  if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
+    enc->endEncoding();
+    enc->release();
+    num_dispatches = 0;
+    outputs.clear();
+    enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
+    enc->retain();
+  }
+}
+
 Device::Device() {
   auto pool = new_scoped_memory_pool();
   device_ = load_device();
@@ -130,9 +157,6 @@ Device::~Device() {
   for (auto& b : buffer_map_) {
     b.second.second->release();
   }
-  for (auto& e : encoder_map_) {
-    (*e.second)->release();
-  }
   for (auto& k : kernel_map_) {
     k.second->release();
   }
@@ -169,27 +193,26 @@ void Device::increment_command_buffer_ops(int index) {
 
 MTL::CommandBuffer* Device::get_command_buffer(int index) {
   auto bit = buffer_map_.find(index);
-  return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
-}
+  if (bit == buffer_map_.end()) {
+    auto qit = queue_map_.find(index);
+    if (qit == queue_map_.end()) {
+      throw std::runtime_error(
+          "[metal::Device] Attempting to get command buffer for invalid queue.");
+    }
 
-MTL::CommandBuffer* Device::new_command_buffer(int index) {
-  auto qit = queue_map_.find(index);
-  if (qit == queue_map_.end()) {
-    throw std::runtime_error(
-        "[metal::Device] Attempting to get command buffer for invalid queue.");
+    auto cb = qit->second->commandBufferWithUnretainedReferences();
+
+    if (!cb) {
+      throw std::runtime_error(
+          "[metal::Device] Unable to create new command buffer");
+    }
+
+    // Increment ref count so the buffer is not garbage collected
+    cb->retain();
+
+    bit = buffer_map_.insert({index, {0, cb}}).first;
   }
-
-  auto cb = qit->second->commandBufferWithUnretainedReferences();
-
-  if (!cb) {
-    throw std::runtime_error(
-        "[metal::Device] Unable to create new command buffer");
-  }
-
-  // Increment ref count so the buffer is not garbage collected
-  cb->retain();
-
-  return buffer_map_.insert({index, {0, cb}}).first->second.second;
+  return bit->second.second;
 }
 
 void Device::commit_command_buffer(int index) {
@@ -200,25 +223,15 @@ void Device::commit_command_buffer(int index) {
 }
 
 void Device::end_encoding(int index) {
-  auto eit = encoder_map_.find(index);
-  if (eit != encoder_map_.end()) {
-    (*eit->second)->endEncoding();
-    (*eit->second)->release();
-    encoder_map_.erase(eit);
-  }
+  encoder_map_.erase(index);
 }
 
 CommandEncoder& Device::get_command_encoder(int index) {
   auto eit = encoder_map_.find(index);
   if (eit == encoder_map_.end()) {
     auto cb = get_command_buffer(index);
-    auto compute_encoder =
-        cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
-    // Increment ref count so the buffer is not garbage collected
-    compute_encoder->retain();
-    eit = encoder_map_
-              .emplace(index, std::make_unique<CommandEncoder>(compute_encoder))
-              .first;
+    eit =
+        encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
   }
   return *(eit->second);
 }
diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h
index 9e5518af6..fa87e4283 100644
--- a/mlx/backend/metal/device.h
+++ b/mlx/backend/metal/device.h
@@ -37,8 +37,10 @@ using MTLFCList = std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
 
 struct CommandEncoder {
-  CommandEncoder(MTL::ComputeCommandEncoder* enc)
-      : enc(enc), concurrent(false) {};
+  CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
+    enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
+    enc->retain();
+  };
   CommandEncoder(const CommandEncoder&) = delete;
   CommandEncoder& operator=(const CommandEncoder&) = delete;
 
@@ -89,13 +91,25 @@ struct CommandEncoder {
     }
   }
 
+  void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
+  void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
+
   ConcurrentContext start_concurrent() {
     return ConcurrentContext(*this);
   }
 
+  ~CommandEncoder() {
+    enc->endEncoding();
+    enc->release();
+  }
+
  private:
+  void maybe_split();
+
+  int num_dispatches{0};
+  MTL::CommandBuffer* cbuf;
   MTL::ComputeCommandEncoder* enc;
-  bool concurrent;
+  bool concurrent{false};
   std::unordered_set<MTL::Resource*> outputs;
   std::unordered_set<MTL::Resource*> concurrent_outputs;
 };
@@ -112,7 +126,6 @@ class Device {
   };
 
   void new_queue(int index);
-  MTL::CommandBuffer* new_command_buffer(int index);
   MTL::CommandBuffer* get_command_buffer(int index);
   int get_command_buffer_ops(int index);
   void increment_command_buffer_ops(int index);
diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp
index 3c851d23f..9f64cefd6 100644
--- a/mlx/backend/metal/fft.cpp
+++ b/mlx/backend/metal/fft.cpp
@@ -97,7 +97,7 @@ void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
     auto group_dims = MTL::Size(1, m, 1);
     auto grid_dims = MTL::Size(batch, m, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp
index b40c9c8c9..cb1faf058 100644
--- a/mlx/backend/metal/indexing.cpp
+++ b/mlx/backend/metal/indexing.cpp
@@ -107,7 +107,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   // Launch grid
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -216,7 +216,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Launch grid
     MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
     MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
   } else {
     // Collect all idx shapes and strides into one place
@@ -286,7 +286,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Launch grid
     MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
     MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index dc99ead61..f82d315ba 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -356,7 +356,7 @@ void steel_matmul_conv_groups(
   compute_encoder->setBytes(
       batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   // Clear copies
   d.get_command_buffer(s.index)->addCompletedHandler(
@@ -468,7 +468,7 @@ void steel_matmul(
   compute_encoder.set_output_array(C_split, 2);
 
   compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   // Do accum kernel
   {
@@ -493,7 +493,7 @@ void steel_matmul(
     MTL::Size grid_dims = MTL::Size(N, M, 1);
     MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
 
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 
   d.get_command_buffer(s.index)->addCompletedHandler(
@@ -581,7 +581,7 @@ void steel_matmul(
   compute_encoder->setBytes(
       batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   // Clear copies
   d.get_command_buffer(s.index)->addCompletedHandler(
@@ -748,7 +748,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setBytes(
       batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -968,7 +968,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   int bias_stride = c.strides()[c.ndim() - 1];
   compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -1038,7 +1038,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_output_array(C_split, 2);
 
   compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   // Do accum kernel
   {
@@ -1063,7 +1063,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     MTL::Size grid_dims = MTL::Size(N, M, 1);
     MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
 
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 
   d.get_command_buffer(s.index)->addCompletedHandler(
@@ -1160,7 +1160,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setBytes(
       batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -1346,7 +1346,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_input_array(out_mask, 10);
   set_vector_bytes(compute_encoder, mask_strides, 13);
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   // Clear copies
   d.get_command_buffer(s.index)->addCompletedHandler(
@@ -1566,7 +1566,7 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
   compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -1656,7 +1656,7 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   set_vector_bytes(compute_encoder, batch_strides_B, 15);
   set_vector_bytes(compute_encoder, operand_batch_ndim, 16);
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   // Clear copies
   d.get_command_buffer(s.index)->addCompletedHandler(
diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp
index 2cdbc49a5..3afe47159 100644
--- a/mlx/backend/metal/metal.cpp
+++ b/mlx/backend/metal/metal.cpp
@@ -27,24 +27,6 @@ int max_ops_per_buffer() {
 
 #define MAX_OPS_PER_BUFFER max_ops_per_buffer()
 
-MTL::CommandBuffer* increment_command_buffer(Stream s) {
-  auto& d = metal::device(s.device);
-  auto command_buffer = d.get_command_buffer(s.index);
-  if (command_buffer == nullptr ||
-      d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
-    if (command_buffer != nullptr) {
-      d.end_encoding(s.index);
-      scheduler::notify_new_task(s);
-      command_buffer->addCompletedHandler(
-          [s](MTL::CommandBuffer*) { scheduler::notify_task_completion(s); });
-      d.commit_command_buffer(s.index);
-    }
-    command_buffer = d.new_command_buffer(s.index);
-  }
-  d.increment_command_buffer_ops(s.index);
-  return command_buffer;
-}
-
 inline void check_error(MTL::CommandBuffer* cbuf) {
   if (cbuf->status() == MTL::CommandBufferStatusError) {
     std::ostringstream msg;
@@ -58,7 +40,10 @@ std::function<void()> make_task(array arr, bool signal) {
   auto task = [arr = std::move(arr), signal]() mutable {
     auto pool = new_scoped_memory_pool();
     auto s = arr.primitive().stream();
-    auto command_buffer = increment_command_buffer(s);
+    auto& d = metal::device(s.device);
+    auto command_buffer = d.get_command_buffer(s.index);
+    d.increment_command_buffer_ops(s.index);
+
     for (auto& input : arr.inputs()) {
       if (input.event().valid() &&
           input.event().stream() != arr.primitive().stream()) {
@@ -91,11 +76,13 @@ std::function<void()> make_task(array arr, bool signal) {
       arr.detach();
     }
 
-    if (signal) {
-      metal::device(s.device).end_encoding(s.index);
-      command_buffer->encodeSignalEvent(
-          static_cast<MTL::Event*>(arr.event().raw_event().get()),
-          arr.event().value());
+    if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
+      d.end_encoding(s.index);
+      if (signal) {
+        command_buffer->encodeSignalEvent(
+            static_cast<MTL::Event*>(arr.event().raw_event().get()),
+            arr.event().value());
+      }
       scheduler::notify_new_task(s);
       command_buffer->addCompletedHandler(
           [s, buffers = std::move(buffers), event = arr.event()](
@@ -103,7 +90,8 @@ std::function<void()> make_task(array arr, bool signal) {
             scheduler::notify_task_completion(s);
             check_error(cbuf);
           });
-      metal::device(s.device).commit_command_buffer(s.index);
+      d.commit_command_buffer(s.index);
+      d.get_command_buffer(s.index);
     } else {
       command_buffer->addCompletedHandler(
           [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
@@ -120,14 +108,12 @@ std::function<void()> make_synchronize_task(
   return [s, p = std::move(p)]() {
     auto& d = metal::device(s.device);
     auto cb = d.get_command_buffer(s.index);
-    if (cb == nullptr) {
-      cb = d.new_command_buffer(s.index);
-    } else {
-      d.end_encoding(s.index);
-    }
+    cb->retain();
+    d.end_encoding(s.index);
     d.commit_command_buffer(s.index);
     cb->waitUntilCompleted();
     check_error(cb);
+    cb->release();
     p->set_value();
   };
 }
diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp
index 61728b5f9..5f728cd33 100644
--- a/mlx/backend/metal/normalization.cpp
+++ b/mlx/backend/metal/normalization.cpp
@@ -89,7 +89,7 @@ void RMSNorm::eval_gpu(
     compute_encoder->setThreadgroupMemoryLength(
         16 * 8, 0); // minimum of 16 bytes
     compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -190,7 +190,7 @@ void RMSNormVJP::eval_gpu(
     compute_encoder->setBytes(&eps_, sizeof(float), 5);
     compute_encoder->setBytes(&axis_size, sizeof(int), 6);
     compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 
   ReductionPlan plan(
@@ -282,7 +282,7 @@ void LayerNorm::eval_gpu(
     compute_encoder->setBytes(&axis_size, sizeof(int), 5);
     compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
     compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -401,7 +401,7 @@ void LayerNormVJP::eval_gpu(
     compute_encoder->setBytes(&eps_, sizeof(float), 5);
     compute_encoder->setBytes(&axis_size, sizeof(int), 6);
     compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 
   if (gw.ndim() == 1 && gw.size() == axis_size) {
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index c4ed2618b..06e9735a5 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -107,7 +107,7 @@ void binary_op(
     }
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     // Launch a 1D grid of threads
     size_t nthreads = out.data_size();
@@ -117,7 +117,7 @@ void binary_op(
       thread_group_size = nthreads;
     }
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
 
@@ -201,7 +201,7 @@ void binary_op(
     }
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     // Launch a 1D grid of threads
     size_t nthreads =
@@ -212,7 +212,7 @@ void binary_op(
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
 
@@ -288,7 +288,7 @@ void ternary_op(
     }
     MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     // Launch a 1D grid of threads
     size_t nthreads = out.data_size();
@@ -298,7 +298,7 @@ void ternary_op(
      thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
 
@@ -351,7 +351,7 @@ void unary_op(
     int ndim = in.ndim();
     compute_encoder->setBytes(&ndim, sizeof(int), 4);
   }
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 } // namespace
@@ -428,7 +428,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
   compute_encoder.set_output_array(out, 2);
 
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -523,7 +523,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
     compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
     compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
 
@@ -834,7 +834,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
         keys.strides().data(), keys.ndim() * sizeof(size_t), 6);
   }
 
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index eb060e7e9..4f48f9ce8 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -65,7 +65,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       compute_encoder->setBytes(&D, sizeof(int), 5);
       compute_encoder->setBytes(&O, sizeof(int), 6);
 
-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }
 
     // Route to the qmv kernel
@@ -92,7 +92,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
      compute_encoder->setBytes(&D, sizeof(int), 5);
       compute_encoder->setBytes(&O, sizeof(int), 6);
 
-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }
 
     // Route to the qmm_t kernel
@@ -123,7 +123,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       compute_encoder->setBytes(&O, sizeof(int), 6);
       compute_encoder->setBytes(&D, sizeof(int), 7);
 
-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }
   } else {
     // Route to the qvm kernel
@@ -150,7 +150,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       compute_encoder->setBytes(&D, sizeof(int), 5);
       compute_encoder->setBytes(&O, sizeof(int), 6);
 
-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }
 
     // Route to the qmm_n kernel
@@ -188,7 +188,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       compute_encoder->setBytes(&O, sizeof(int), 6);
       compute_encoder->setBytes(&D, sizeof(int), 7);
 
-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }
   }
diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp
index fc0d1ca1a..f4a638bc4 100644
--- a/mlx/backend/metal/reduce.cpp
+++ b/mlx/backend/metal/reduce.cpp
@@ -74,7 +74,7 @@ void all_reduce_dispatch(
     compute_encoder.set_input_array(in, 0);
     compute_encoder.set_output_array(out, 1);
     compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     // Allocate intermediate array to store partial reduction results
@@ -88,7 +88,7 @@ void all_reduce_dispatch(
     compute_encoder.set_input_array(in, 0);
     compute_encoder.set_output_array(intermediate, 1);
     compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
     // Second pass to reduce intermediate reduction results written to DRAM
     compute_encoder.set_input_array(intermediate, 0);
@@ -108,7 +108,7 @@ void all_reduce_dispatch(
     nthreads = thread_group_size;
     group_dims = MTL::Size(thread_group_size, 1, 1);
     grid_dims = MTL::Size(nthreads, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
     d.get_command_buffer(s.index)->addCompletedHandler(
         [intermediates](MTL::CommandBuffer*) mutable {
@@ -217,7 +217,7 @@ void row_reduce_general_dispatch(
     compute_encoder->setBytes(
         strides.data(), strides.size() * sizeof(size_t), 6);
     compute_encoder->setBytes(&ndim, sizeof(int), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     // Allocate intermediate array to store partial reduction results
@@ -239,7 +239,7 @@ void row_reduce_general_dispatch(
     compute_encoder->setBytes(
         strides.data(), strides.size() * sizeof(size_t), 6);
     compute_encoder->setBytes(&ndim, sizeof(int), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
     // Set up second dispatch
     reduction_size = non_row_reductions;
@@ -286,7 +286,7 @@ void row_reduce_general_dispatch(
     grid_dims = MTL::Size(n_threads, out.size(), 1);
     group_dims = MTL::Size(thread_group_size, 1, 1);
 
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
     d.get_command_buffer(s.index)->addCompletedHandler(
         [intermediates](MTL::CommandBuffer*) mutable {
@@ -366,7 +366,7 @@ void strided_reduce_general_dispatch(
     compute_encoder->setBytes(&non_col_ndim, sizeof(int), 11);
 
     // Dispatch threads
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
     return;
   }
@@ -435,7 +435,7 @@ void strided_reduce_general_dispatch(
     compute_encoder->setThreadgroupMemoryLength(
         safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
         0);
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   } else {
     // Allocate intermediate array to store reduction results from all thread
@@ -470,7 +470,7 @@ void strided_reduce_general_dispatch(
     compute_encoder->setThreadgroupMemoryLength(
         safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
         0);
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
     // Perform second pass of reductions
     // Reduce results of threadgroups along y, z from first pass, that
@@ -523,7 +523,7 @@ void strided_reduce_general_dispatch(
     grid_dims = MTL::Size(n_threads, out.size(), 1);
     group_dims = MTL::Size(thread_group_size, 1, 1);
 
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
     d.get_command_buffer(s.index)->addCompletedHandler(
         [intermediates](MTL::CommandBuffer*) mutable {
@@ -585,7 +585,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
     compute_encoder->setComputePipelineState(kernel);
     compute_encoder.set_output_array(out, 0);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 
   // Reduce
diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp
index 1151f8c43..c19ad52a4 100644
--- a/mlx/backend/metal/rope.cpp
+++ b/mlx/backend/metal/rope.cpp
@@ -83,7 +83,7 @@ void RoPE::eval_gpu(
   int dim2 = in.size() / mat_size;
   auto group_dims = get_block_dims(dim0, dim1, dim2);
   auto grid_dims = MTL::Size(dim0, dim1, dim2);
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 } // namespace mlx::core::fast
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 5b5a68870..0ded93397 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -99,7 +99,7 @@ void sdpa_metal(
   constexpr const uint tgroupMemorySize = 32768;
   compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   {
     auto kernel_accum = d.get_kernel(kname_reduce.str());
@@ -114,7 +114,7 @@ void sdpa_metal(
     MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
     MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
 
-    compute_encoder->dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
+    compute_encoder.dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
 
     d.get_command_buffer(s.index)->addCompletedHandler(
         [temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp
index 94757b1e7..44c8fe5db 100644
--- a/mlx/backend/metal/scan.cpp
+++ b/mlx/backend/metal/scan.cpp
@@ -77,7 +77,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
         static_cast(kernel->maxTotalThreadsPerThreadgroup()));
     MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     kname << "strided_scan_";
     if (reverse_) {
@@ -119,7 +119,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
     MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
     MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 
   if (copies.size() > 0) {
diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp
index 1fbc1e00c..41173dd16 100644
--- a/mlx/backend/metal/softmax.cpp
+++ b/mlx/backend/metal/softmax.cpp
@@ -85,7 +85,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.set_input_array(
         in.data_shared_ptr() == nullptr ? out : in, 0);
     compute_encoder.set_output_array(out, 1);
     compute_encoder->setBytes(&axis_size, sizeof(int), 2);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp
index 9f53779a0..528f2e951 100644
--- a/mlx/backend/metal/sort.cpp
+++ b/mlx/backend/metal/sort.cpp
@@ -78,7 +78,7 @@ void single_block_sort(
   MTL::Size group_dims = MTL::Size(bn, 1, 1);
   MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }
 
 template
@@ -155,7 +155,7 @@ void multi_block_sort(
     MTL::Size group_dims = MTL::Size(bn, 1, 1);
     MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
 
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }
 
   // Do merges
@@ -190,7 +190,7 @@ void multi_block_sort(
     MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
     MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
 
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }
 
   // Do merge
@@ -214,7 +214,7 @@ void multi_block_sort(
     MTL::Size group_dims = MTL::Size(bn, 1, 1);
     MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
 
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }
 }