mirror of https://github.com/ml-explore/mlx.git

	Split encoders in non-concurrent context with a max ops per encoder (#1085)
Author: Awni Hannun

* split encoders
* fix race
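At a glance: `CommandEncoder` now owns dispatch bookkeeping. Each `dispatchThreads` / `dispatchThreadgroups` call increments a counter, and outside a concurrent context the encoder is ended and reopened on the same command buffer once the counter passes `MAX_DISPATCHES_PER_ENCODER`. A minimal standalone sketch of that policy, with `ToyEncoder` as an illustrative stand-in (not MLX code, no Metal calls):

// toy_encoder.cpp -- illustrative only; ToyEncoder stands in for
// mlx::core::metal::CommandEncoder.
#include <cstdio>

constexpr int MAX_DISPATCHES_PER_ENCODER = 2;

struct ToyEncoder {
  int num_dispatches = 0;
  bool concurrent = false; // true inside a ConcurrentContext
  int generation = 0;      // how many underlying encoders have been opened

  void dispatchThreads() {
    num_dispatches++;
    std::printf("dispatch %d on encoder #%d\n", num_dispatches, generation);
    maybe_split();
  }

  // Mirrors CommandEncoder::maybe_split(): only split when the dispatch
  // budget is exhausted and we are not inside a concurrent context.
  void maybe_split() {
    if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
      num_dispatches = 0;
      generation++; // MLX ends the Metal encoder here and opens a new one
    }
  }
};

int main() {
  ToyEncoder enc;
  for (int i = 0; i < 7; ++i) {
    enc.dispatchThreads(); // splits after every MAX_DISPATCHES_PER_ENCODER + 1 dispatches
  }
}
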
@@ -336,7 +336,7 @@ void Compiled::eval_gpu(
     MTL::Size grid_dims(nthreads, 1, 1);
     MTL::Size group_dims(
         std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
     size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@@ -347,7 +347,7 @@ void Compiled::eval_gpu(
     }
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
 
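Most hunks in this commit are this same mechanical change: dispatches now go through the `CommandEncoder` wrapper (`compute_encoder.dispatchThreads(...)`) so they can be counted, while other calls such as `setBytes` keep using the raw-encoder pass-through (`compute_encoder->setBytes(...)`). A hedged sketch of that split, with stand-in types rather than the real Metal ones:

// wrapper.cpp -- stand-in types, illustrative only.
#include <cstdio>

struct RawEncoder {
  void setBytes(const void*, int) { std::printf("setBytes (uncounted)\n"); }
  void dispatchThreads() { std::printf("raw dispatch\n"); }
};

struct Encoder {
  RawEncoder* operator->() { return &raw; } // pass-through, as before

  void dispatchThreads() { // counted wrapper, new in this commit
    num_dispatches++;
    raw.dispatchThreads();
    // maybe_split() would run here
  }

  RawEncoder raw;
  int num_dispatches = 0;
};

int main() {
  Encoder enc;
  int x = 0;
  enc->setBytes(&x, sizeof(x)); // still reaches the raw encoder via operator->
  enc.dispatchThreads();        // dispatches must use the counting wrapper
  std::printf("counted dispatches: %d\n", enc.num_dispatches);
}
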
@@ -59,7 +59,7 @@ void explicit_gemm_conv_ND_gpu(
   MTL::Size grid_dims = MTL::Size(
       conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
 
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 
   // Reshape weight
   std::vector<int> wt_reshape{implicit_K, implicit_N};
@@ -137,7 +137,7 @@ void explicit_gemm_conv_group_ND_gpu(
   MTL::Size grid_dims = MTL::Size(
       conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
 
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 
   // Transpose kernel weights so that we can slice them by contiguous chunks
   // of channel groups.
@@ -247,7 +247,7 @@ void slow_conv_2D_gpu(
   compute_encoder.set_output_array(out, 2);
 
   compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }
 
 void implicit_gemm_conv_2D_gpu(
@@ -352,7 +352,7 @@ void implicit_gemm_conv_2D_gpu(
   compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
 
   // Launch kernel
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }
 
 void implicit_gemm_conv_2D_general_gpu(
@@ -512,7 +512,7 @@ void implicit_gemm_conv_2D_general_gpu(
       base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
 
   // Launch kernel
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }
 
 void winograd_conv_2D_gpu(
@@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
     MTL::Size group_dims = MTL::Size(32, bo, 1);
     MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
 
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }
 
   // Do input transform
@@ -641,7 +641,7 @@ void winograd_conv_2D_gpu(
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
 
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }
 
   // Do batched gemm
@@ -689,7 +689,7 @@ void winograd_conv_2D_gpu(
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
 
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }
 }
 
@@ -126,7 +126,7 @@ void copy_gpu_inplace(
 
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     size_t nthreads = out.data_size();
     MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
@@ -135,7 +135,7 @@ void copy_gpu_inplace(
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
 
@@ -25,6 +25,7 @@ namespace {
 
 // TODO nicer way to set this or possibly expose as an environment variable
 constexpr int MAX_BUFFERS_PER_QUEUE = 12;
+constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
 
 constexpr const char* default_mtllib_path = METAL_PATH;
 
@@ -37,7 +38,6 @@ auto load_device() {
   }
   return device;
 }
-
 std::pair<MTL::Library*, NS::Error*> load_library_from_path(
     MTL::Device* device,
     const char* path) {
@@ -116,6 +116,33 @@ MTL::Library* load_library(
 
 } // namespace
 
+void CommandEncoder::dispatchThreadgroups(
+    MTL::Size grid_dims,
+    MTL::Size group_dims) {
+  num_dispatches++;
+  enc->dispatchThreadgroups(grid_dims, group_dims);
+  maybe_split();
+}
+
+void CommandEncoder::dispatchThreads(
+    MTL::Size grid_dims,
+    MTL::Size group_dims) {
+  num_dispatches++;
+  enc->dispatchThreads(grid_dims, group_dims);
+  maybe_split();
+}
+
+void CommandEncoder::maybe_split() {
+  if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
+    enc->endEncoding();
+    enc->release();
+    num_dispatches = 0;
+    outputs.clear();
+    enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
+    enc->retain();
+  }
+}
+
 Device::Device() {
   auto pool = new_scoped_memory_pool();
   device_ = load_device();
@@ -130,9 +157,6 @@ Device::~Device() {
   for (auto& b : buffer_map_) {
     b.second.second->release();
   }
-  for (auto& e : encoder_map_) {
-    (*e.second)->release();
-  }
   for (auto& k : kernel_map_) {
     k.second->release();
   }
@@ -169,27 +193,26 @@ void Device::increment_command_buffer_ops(int index) {
 
 MTL::CommandBuffer* Device::get_command_buffer(int index) {
   auto bit = buffer_map_.find(index);
-  return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
-}
-
-MTL::CommandBuffer* Device::new_command_buffer(int index) {
-  auto qit = queue_map_.find(index);
-  if (qit == queue_map_.end()) {
-    throw std::runtime_error(
-        "[metal::Device] Attempting to get command buffer for invalid queue.");
-  }
-
-  auto cb = qit->second->commandBufferWithUnretainedReferences();
-
-  if (!cb) {
-    throw std::runtime_error(
-        "[metal::Device] Unable to create new command buffer");
-  }
-
-  // Increment ref count so the buffer is not garbage collected
-  cb->retain();
-
-  return buffer_map_.insert({index, {0, cb}}).first->second.second;
+  if (bit == buffer_map_.end()) {
+    auto qit = queue_map_.find(index);
+    if (qit == queue_map_.end()) {
+      throw std::runtime_error(
+          "[metal::Device] Attempting to get command buffer for invalid queue.");
+    }
+
+    auto cb = qit->second->commandBufferWithUnretainedReferences();
+
+    if (!cb) {
+      throw std::runtime_error(
+          "[metal::Device] Unable to create new command buffer");
+    }
+
+    // Increment ref count so the buffer is not garbage collected
+    cb->retain();
+
+    bit = buffer_map_.insert({index, {0, cb}}).first;
+  }
+  return bit->second.second;
 }
 
 void Device::commit_command_buffer(int index) {
@@ -200,25 +223,15 @@ void Device::commit_command_buffer(int index) {
 }
 
 void Device::end_encoding(int index) {
-  auto eit = encoder_map_.find(index);
-  if (eit != encoder_map_.end()) {
-    (*eit->second)->endEncoding();
-    (*eit->second)->release();
-    encoder_map_.erase(eit);
-  }
+  encoder_map_.erase(index);
 }
 
 CommandEncoder& Device::get_command_encoder(int index) {
   auto eit = encoder_map_.find(index);
   if (eit == encoder_map_.end()) {
     auto cb = get_command_buffer(index);
-    auto compute_encoder =
-        cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
-    // Increment ref count so the buffer is not garbage collected
-    compute_encoder->retain();
-    eit = encoder_map_
-              .emplace(index, std::make_unique<CommandEncoder>(compute_encoder))
-              .first;
+    eit =
+        encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
   }
   return *(eit->second);
 }
 
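Note how moving encoder setup into the `CommandEncoder` constructor and teardown into its destructor lets `Device::end_encoding` collapse to a map erase. A tiny sketch of the ownership pattern (stand-in types, no Metal calls):

// raii_encoder.cpp -- stand-in types, illustrative only.
#include <cstdio>
#include <memory>
#include <unordered_map>

struct Encoder {
  Encoder() { std::printf("begin encoding\n"); }  // computeCommandEncoder + retain in MLX
  ~Encoder() { std::printf("end encoding\n"); }   // endEncoding + release in MLX
};

int main() {
  std::unordered_map<int, std::unique_ptr<Encoder>> encoder_map;
  encoder_map.emplace(0, std::make_unique<Encoder>());
  // Device::end_encoding(index) is now just this erase: the unique_ptr
  // destructor ends and releases the underlying encoder.
  encoder_map.erase(0);
}
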
@@ -37,8 +37,10 @@ using MTLFCList =
     std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
 
 struct CommandEncoder {
-  CommandEncoder(MTL::ComputeCommandEncoder* enc)
-      : enc(enc), concurrent(false) {};
+  CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
+    enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
+    enc->retain();
+  };
   CommandEncoder(const CommandEncoder&) = delete;
   CommandEncoder& operator=(const CommandEncoder&) = delete;
 
@@ -89,13 +91,25 @@ struct CommandEncoder {
     }
   }
 
+  void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
+  void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
+
   ConcurrentContext start_concurrent() {
     return ConcurrentContext(*this);
   }
 
+  ~CommandEncoder() {
+    enc->endEncoding();
+    enc->release();
+  }
+
+ private:
+  void maybe_split();
+
+  int num_dispatches{0};
+  MTL::CommandBuffer* cbuf;
   MTL::ComputeCommandEncoder* enc;
-  bool concurrent;
+  bool concurrent{false};
   std::unordered_set<MTL::Resource*> outputs;
   std::unordered_set<MTL::Resource*> concurrent_outputs;
 };
@@ -112,7 +126,6 @@ class Device {
   };
 
   void new_queue(int index);
-  MTL::CommandBuffer* new_command_buffer(int index);
   MTL::CommandBuffer* get_command_buffer(int index);
   int get_command_buffer_ops(int index);
   void increment_command_buffer_ops(int index);
 
@@ -97,7 +97,7 @@ void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
 
     auto group_dims = MTL::Size(1, m, 1);
     auto grid_dims = MTL::Size(batch, m, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
 
@@ -107,7 +107,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   // Launch grid
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -216,7 +216,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Launch grid
     MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
     MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
   } else {
     // Collect all idx shapes and strides into one place
@@ -286,7 +286,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Launch grid
     MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
     MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
 
@@ -356,7 +356,7 @@ void steel_matmul_conv_groups(
   compute_encoder->setBytes(
       batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   // Clear copies
   d.get_command_buffer(s.index)->addCompletedHandler(
@@ -468,7 +468,7 @@ void steel_matmul(
     compute_encoder.set_output_array(C_split, 2);
 
     compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
     // Do accum kernel
     {
@@ -493,7 +493,7 @@ void steel_matmul(
       MTL::Size grid_dims = MTL::Size(N, M, 1);
       MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
 
-      compute_encoder->dispatchThreads(grid_dims, group_dims);
+      compute_encoder.dispatchThreads(grid_dims, group_dims);
     }
 
     d.get_command_buffer(s.index)->addCompletedHandler(
@@ -581,7 +581,7 @@ void steel_matmul(
   compute_encoder->setBytes(
       batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   // Clear copies
   d.get_command_buffer(s.index)->addCompletedHandler(
@@ -748,7 +748,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder->setBytes(
         batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
 
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
     d.get_command_buffer(s.index)->addCompletedHandler(
         [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -968,7 +968,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     int bias_stride = c.strides()[c.ndim() - 1];
     compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
 
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
     d.get_command_buffer(s.index)->addCompletedHandler(
         [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -1038,7 +1038,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.set_output_array(C_split, 2);
 
     compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
     // Do accum kernel
     {
@@ -1063,7 +1063,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       MTL::Size grid_dims = MTL::Size(N, M, 1);
      MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
 
-      compute_encoder->dispatchThreads(grid_dims, group_dims);
+      compute_encoder.dispatchThreads(grid_dims, group_dims);
     }
 
     d.get_command_buffer(s.index)->addCompletedHandler(
@@ -1160,7 +1160,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setBytes(
       batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -1346,7 +1346,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_input_array(out_mask, 10);
   set_vector_bytes(compute_encoder, mask_strides, 13);
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   // Clear copies
   d.get_command_buffer(s.index)->addCompletedHandler(
@@ -1566,7 +1566,7 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
     compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
 
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
     d.get_command_buffer(s.index)->addCompletedHandler(
         [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -1656,7 +1656,7 @@ void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   set_vector_bytes(compute_encoder, batch_strides_B, 15);
   set_vector_bytes(compute_encoder, operand_batch_ndim, 16);
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   // Clear copies
   d.get_command_buffer(s.index)->addCompletedHandler(
 
@@ -27,24 +27,6 @@ int max_ops_per_buffer() {
 
 #define MAX_OPS_PER_BUFFER max_ops_per_buffer()
 
-MTL::CommandBuffer* increment_command_buffer(Stream s) {
-  auto& d = metal::device(s.device);
-  auto command_buffer = d.get_command_buffer(s.index);
-  if (command_buffer == nullptr ||
-      d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
-    if (command_buffer != nullptr) {
-      d.end_encoding(s.index);
-      scheduler::notify_new_task(s);
-      command_buffer->addCompletedHandler(
-          [s](MTL::CommandBuffer*) { scheduler::notify_task_completion(s); });
-      d.commit_command_buffer(s.index);
-    }
-    command_buffer = d.new_command_buffer(s.index);
-  }
-  d.increment_command_buffer_ops(s.index);
-  return command_buffer;
-}
-
 inline void check_error(MTL::CommandBuffer* cbuf) {
   if (cbuf->status() == MTL::CommandBufferStatusError) {
     std::ostringstream msg;
@@ -58,7 +40,10 @@ std::function<void()> make_task(array arr, bool signal) {
   auto task = [arr = std::move(arr), signal]() mutable {
     auto pool = new_scoped_memory_pool();
     auto s = arr.primitive().stream();
-    auto command_buffer = increment_command_buffer(s);
+    auto& d = metal::device(s.device);
+    auto command_buffer = d.get_command_buffer(s.index);
+    d.increment_command_buffer_ops(s.index);
+
     for (auto& input : arr.inputs()) {
       if (input.event().valid() &&
           input.event().stream() != arr.primitive().stream()) {
@@ -91,11 +76,13 @@ std::function<void()> make_task(array arr, bool signal) {
       arr.detach();
     }
 
-    if (signal) {
-      metal::device(s.device).end_encoding(s.index);
-      command_buffer->encodeSignalEvent(
-          static_cast<MTL::Event*>(arr.event().raw_event().get()),
-          arr.event().value());
+    if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
+      d.end_encoding(s.index);
+      if (signal) {
+        command_buffer->encodeSignalEvent(
+            static_cast<MTL::Event*>(arr.event().raw_event().get()),
+            arr.event().value());
+      }
       scheduler::notify_new_task(s);
       command_buffer->addCompletedHandler(
           [s, buffers = std::move(buffers), event = arr.event()](
@@ -103,7 +90,8 @@ std::function<void()> make_task(array arr, bool signal) {
             scheduler::notify_task_completion(s);
            check_error(cbuf);
          });
-      metal::device(s.device).commit_command_buffer(s.index);
+      d.commit_command_buffer(s.index);
+      d.get_command_buffer(s.index);
     } else {
       command_buffer->addCompletedHandler(
           [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
@@ -120,14 +108,12 @@ std::function<void()> make_synchronize_task(
   return [s, p = std::move(p)]() {
     auto& d = metal::device(s.device);
     auto cb = d.get_command_buffer(s.index);
-    if (cb == nullptr) {
-      cb = d.new_command_buffer(s.index);
-    } else {
-      d.end_encoding(s.index);
-    }
+    cb->retain();
+    d.end_encoding(s.index);
     d.commit_command_buffer(s.index);
     cb->waitUntilCompleted();
     check_error(cb);
+    cb->release();
     p->set_value();
   };
 }
 
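With the standalone `increment_command_buffer(Stream)` helper gone, buffer rotation happens inside `make_task`: the buffer is committed either when a signal event must be encoded or when the op count reaches `MAX_OPS_PER_BUFFER`, and the next buffer is created eagerly right after the commit. A simplified model of that flow (`ToyDevice` is illustrative; the real threshold comes from `max_ops_per_buffer()`):

// commit_policy.cpp -- ToyDevice is an illustrative stand-in, not MLX code.
#include <cstdio>

constexpr int MAX_OPS_PER_BUFFER = 10;

struct ToyDevice {
  int ops = 0;
  bool has_buffer = false;

  void get_command_buffer() { // creates the buffer lazily, as in Device::get_command_buffer
    if (!has_buffer) {
      has_buffer = true;
      ops = 0;
      std::printf("new command buffer\n");
    }
  }

  void run_task(bool signal) {
    get_command_buffer();
    ops++; // increment_command_buffer_ops
    if (signal || ops >= MAX_OPS_PER_BUFFER) {
      std::printf("commit after %d ops%s\n", ops, signal ? " (signal)" : "");
      has_buffer = false;   // commit_command_buffer drops the map entry
      get_command_buffer(); // mirrors the eager d.get_command_buffer(s.index) after commit
    }
  }
};

int main() {
  ToyDevice d;
  for (int i = 0; i < 25; ++i) {
    d.run_task(i == 12); // one task in the middle needs a signal event
  }
}

The `cb->retain()` / `cb->release()` pair added to `make_synchronize_task` appears to be the "fix race" half of the commit message: `commit_command_buffer` drops the device's reference to the buffer, so the synchronize task pins it for the duration of `waitUntilCompleted` and `check_error`.
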
@@ -89,7 +89,7 @@ void RMSNorm::eval_gpu(
     compute_encoder->setThreadgroupMemoryLength(
         16 * 8, 0); // minimum of 16 bytes
     compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -190,7 +190,7 @@ void RMSNormVJP::eval_gpu(
     compute_encoder->setBytes(&eps_, sizeof(float), 5);
     compute_encoder->setBytes(&axis_size, sizeof(int), 6);
     compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 
   ReductionPlan plan(
@@ -282,7 +282,7 @@ void LayerNorm::eval_gpu(
     compute_encoder->setBytes(&axis_size, sizeof(int), 5);
     compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
     compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -401,7 +401,7 @@ void LayerNormVJP::eval_gpu(
     compute_encoder->setBytes(&eps_, sizeof(float), 5);
     compute_encoder->setBytes(&axis_size, sizeof(int), 6);
     compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 
   if (gw.ndim() == 1 && gw.size() == axis_size) {
 
@@ -107,7 +107,7 @@ void binary_op(
     }
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     // Launch a 1D grid of threads
     size_t nthreads = out.data_size();
@@ -117,7 +117,7 @@ void binary_op(
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
 
@@ -201,7 +201,7 @@ void binary_op(
     }
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     // Launch a 1D grid of threads
     size_t nthreads =
@@ -212,7 +212,7 @@ void binary_op(
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
 
@@ -288,7 +288,7 @@ void ternary_op(
     }
     MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     // Launch a 1D grid of threads
     size_t nthreads = out.data_size();
@@ -298,7 +298,7 @@ void ternary_op(
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
 
@@ -351,7 +351,7 @@ void unary_op(
     int ndim = in.ndim();
     compute_encoder->setBytes(&ndim, sizeof(int), 4);
   }
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 } // namespace
@@ -428,7 +428,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   compute_encoder.set_output_array(out, 2);
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -523,7 +523,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
     compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
     compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
 
@@ -834,7 +834,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
         keys.strides().data(), keys.ndim() * sizeof(size_t), 6);
   }
 
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
 
@@ -65,7 +65,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       compute_encoder->setBytes(&D, sizeof(int), 5);
       compute_encoder->setBytes(&O, sizeof(int), 6);
 
-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }
 
     // Route to the qmv kernel
@@ -92,7 +92,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       compute_encoder->setBytes(&D, sizeof(int), 5);
      compute_encoder->setBytes(&O, sizeof(int), 6);
 
-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }
 
     // Route to the qmm_t kernel
@@ -123,7 +123,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       compute_encoder->setBytes(&O, sizeof(int), 6);
      compute_encoder->setBytes(&D, sizeof(int), 7);
 
-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }
   } else {
     // Route to the qvm kernel
@@ -150,7 +150,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
      compute_encoder->setBytes(&D, sizeof(int), 5);
      compute_encoder->setBytes(&O, sizeof(int), 6);
 
-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }
 
     // Route to the qmm_n kernel
@@ -188,7 +188,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
      compute_encoder->setBytes(&O, sizeof(int), 6);
      compute_encoder->setBytes(&D, sizeof(int), 7);
 
-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }
   }
 
@@ -74,7 +74,7 @@ void all_reduce_dispatch(
     compute_encoder.set_input_array(in, 0);
     compute_encoder.set_output_array(out, 1);
     compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
   } else {
     // Allocate intermediate array to store partial reduction results
@@ -88,7 +88,7 @@ void all_reduce_dispatch(
     compute_encoder.set_input_array(in, 0);
     compute_encoder.set_output_array(intermediate, 1);
     compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
     // Second pass to reduce intermediate reduction results written to DRAM
     compute_encoder.set_input_array(intermediate, 0);
@@ -108,7 +108,7 @@ void all_reduce_dispatch(
     nthreads = thread_group_size;
     group_dims = MTL::Size(thread_group_size, 1, 1);
     grid_dims = MTL::Size(nthreads, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
     d.get_command_buffer(s.index)->addCompletedHandler(
         [intermediates](MTL::CommandBuffer*) mutable {
@@ -217,7 +217,7 @@ void row_reduce_general_dispatch(
     compute_encoder->setBytes(
         strides.data(), strides.size() * sizeof(size_t), 6);
     compute_encoder->setBytes(&ndim, sizeof(int), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
   } else {
     // Allocate intermediate array to store partial reduction results
@@ -239,7 +239,7 @@ void row_reduce_general_dispatch(
     compute_encoder->setBytes(
         strides.data(), strides.size() * sizeof(size_t), 6);
     compute_encoder->setBytes(&ndim, sizeof(int), 7);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
     // Set up second dispatch
     reduction_size = non_row_reductions;
@@ -286,7 +286,7 @@ void row_reduce_general_dispatch(
     grid_dims = MTL::Size(n_threads, out.size(), 1);
     group_dims = MTL::Size(thread_group_size, 1, 1);
 
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
     d.get_command_buffer(s.index)->addCompletedHandler(
         [intermediates](MTL::CommandBuffer*) mutable {
@@ -366,7 +366,7 @@ void strided_reduce_general_dispatch(
     compute_encoder->setBytes(&non_col_ndim, sizeof(int), 11);
 
     // Dispatch threads
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
     return;
   }
@@ -435,7 +435,7 @@ void strided_reduce_general_dispatch(
     compute_encoder->setThreadgroupMemoryLength(
         safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
         0);
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   } else {
     // Allocate intermediate array to store reduction results from all thread
@@ -470,7 +470,7 @@ void strided_reduce_general_dispatch(
     compute_encoder->setThreadgroupMemoryLength(
         safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
         0);
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
     // Perform second pass of reductions
     // Reduce results of threadgroups along y, z from first pass, that
@@ -523,7 +523,7 @@ void strided_reduce_general_dispatch(
     grid_dims = MTL::Size(n_threads, out.size(), 1);
     group_dims = MTL::Size(thread_group_size, 1, 1);
 
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
 
     d.get_command_buffer(s.index)->addCompletedHandler(
         [intermediates](MTL::CommandBuffer*) mutable {
@@ -585,7 +585,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
     compute_encoder->setComputePipelineState(kernel);
     compute_encoder.set_output_array(out, 0);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 
   // Reduce
 
@@ -83,7 +83,7 @@ void RoPE::eval_gpu(
   int dim2 = in.size() / mat_size;
   auto group_dims = get_block_dims(dim0, dim1, dim2);
   auto grid_dims = MTL::Size(dim0, dim1, dim2);
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 } // namespace mlx::core::fast
 
@@ -99,7 +99,7 @@ void sdpa_metal(
 
   constexpr const uint tgroupMemorySize = 32768;
   compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 
   {
     auto kernel_accum = d.get_kernel(kname_reduce.str());
@@ -114,7 +114,7 @@ void sdpa_metal(
     MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
     MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
 
-    compute_encoder->dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
+    compute_encoder.dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
 
     d.get_command_buffer(s.index)->addCompletedHandler(
         [temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
 
@@ -77,7 +77,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
         static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
     MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     kname << "strided_scan_";
     if (reverse_) {
@@ -119,7 +119,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
     MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
     MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 
   if (copies.size() > 0) {
 
@@ -85,7 +85,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
         in.data_shared_ptr() == nullptr ? out : in, 0);
     compute_encoder.set_output_array(out, 1);
     compute_encoder->setBytes(&axis_size, sizeof(int), 2);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
 
@@ -78,7 +78,7 @@ void single_block_sort(
   MTL::Size group_dims = MTL::Size(bn, 1, 1);
   MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
 
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }
 
 template <bool ARGSORT>
@@ -155,7 +155,7 @@ void multi_block_sort(
     MTL::Size group_dims = MTL::Size(bn, 1, 1);
     MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
 
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }
 
   // Do merges
@@ -190,7 +190,7 @@ void multi_block_sort(
      MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
      MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
 
-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }
 
     // Do merge
@@ -214,7 +214,7 @@ void multi_block_sort(
      MTL::Size group_dims = MTL::Size(bn, 1, 1);
      MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
 
-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
     }
   }