From 9f0d5c12fcdbfcdd650951eb374eb602fbdce28a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 8 Nov 2024 11:50:21 -0800 Subject: [PATCH] Fully wrap the command encoder (#1572) * fully wrap the command encoder * use consistent style + fix extensions --- docs/src/dev/extensions.rst | 16 +- examples/extensions/axpby/axpby.cpp | 16 +- mlx/backend/metal/allocator.cpp | 3 + mlx/backend/metal/binary.cpp | 22 +- mlx/backend/metal/compiled.cpp | 16 +- mlx/backend/metal/conv.cpp | 64 +++--- mlx/backend/metal/copy.cpp | 18 +- mlx/backend/metal/custom_kernel.cpp | 10 +- mlx/backend/metal/device.cpp | 8 +- mlx/backend/metal/device.h | 39 +++- mlx/backend/metal/fft.cpp | 26 +-- mlx/backend/metal/hadamard.cpp | 6 +- mlx/backend/metal/indexing.cpp | 60 +++--- mlx/backend/metal/kernels/rms_norm.metal | 17 +- mlx/backend/metal/matmul.cpp | 188 +++++++++--------- mlx/backend/metal/normalization.cpp | 45 ++--- mlx/backend/metal/primitives.cpp | 48 +++-- mlx/backend/metal/quantized.cpp | 64 +++--- mlx/backend/metal/reduce.cpp | 110 +++++----- mlx/backend/metal/rope.cpp | 20 +- .../metal/scaled_dot_product_attention.cpp | 27 ++- mlx/backend/metal/scan.cpp | 16 +- mlx/backend/metal/softmax.cpp | 6 +- mlx/backend/metal/sort.cpp | 57 +++--- mlx/backend/metal/ternary.cpp | 22 +- mlx/backend/metal/unary.cpp | 12 +- mlx/backend/metal/utils.h | 17 -- 27 files changed, 469 insertions(+), 484 deletions(-) diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index ecb418468..196f8bf65 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -494,7 +494,7 @@ below. // Prepare to encode kernel auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Kernel parameters are registered with buffer indices corresponding to // those in the kernel declaration at axpby.metal @@ -509,14 +509,14 @@ below. compute_encoder.set_output_array(out, 2); // Encode alpha and beta - compute_encoder->setBytes(&alpha_, sizeof(float), 3); - compute_encoder->setBytes(&beta_, sizeof(float), 4); + compute_encoder.set_bytes(alpha_, 3); + compute_encoder.set_bytes(beta_, 4); // Encode shape, strides and ndim - compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5); - compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6); - compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7); - compute_encoder->setBytes(&ndim, sizeof(int), 8); + compute_encoder.set_vector_bytes(x.shape(), 5); + compute_encoder.set_vector_bytes(x.strides(), 6); + compute_encoder.set_bytes(y.strides(), 7); + compute_encoder.set_bytes(ndim, 8); // We launch 1 thread for each input and make sure that the number of // threads in any given threadgroup is not higher than the max allowed @@ -530,7 +530,7 @@ below. // Launch the grid with the given number of threads divided among // the given threadgroups - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } We can now call the :meth:`axpby` operation on both the CPU and the GPU! 
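For orientation before the per-file hunks below: the documentation change above shows the pattern every remaining hunk follows. What comes next is a minimal sketch, not part of the patch, of a kernel launch written against the wrapped encoder interface this commit introduces; the kernel, the arrays `x`, `y`, `out`, the scalar `alpha`, and the buffer indices are illustrative placeholders, while the member names (`set_compute_pipeline_state`, `set_input_array`, `set_output_array`, `set_bytes`, `set_vector_bytes`, `dispatch_threads`) are the ones added in mlx/backend/metal/device.h in this commit.

    // Sketch only: encoding a kernel through the wrapped CommandEncoder.
    // Scalars and POD structs go through set_bytes, std::vectors through
    // set_vector_bytes, and dispatches use the snake_case wrappers instead
    // of calling the raw MTL::ComputeCommandEncoder via operator->.
    auto& compute_encoder = d.get_command_encoder(s.index);
    compute_encoder.set_compute_pipeline_state(kernel);

    compute_encoder.set_input_array(x, 0);
    compute_encoder.set_input_array(y, 1);
    compute_encoder.set_output_array(out, 2);

    compute_encoder.set_bytes(alpha, 3);            // was: setBytes(&alpha, sizeof(float), 3)
    compute_encoder.set_vector_bytes(x.shape(), 4); // was: setBytes(x.shape().data(), ndim * sizeof(int), 4)

    MTL::Size group_dims = MTL::Size(256, 1, 1);
    MTL::Size grid_dims = MTL::Size(out.size(), 1, 1);
    compute_encoder.dispatch_threads(grid_dims, group_dims);

The hunks that follow apply exactly this substitution (plus dispatch_threadgroups where threadgroup grids are used) across the Metal backend.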
diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp index c8ba7c239..07db2dd0c 100644 --- a/examples/extensions/axpby/axpby.cpp +++ b/examples/extensions/axpby/axpby.cpp @@ -257,7 +257,7 @@ void Axpby::eval_gpu( // Prepare to encode kernel auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Kernel parameters are registered with buffer indices corresponding to // those in the kernel declaration at axpby.metal @@ -272,15 +272,15 @@ void Axpby::eval_gpu( compute_encoder.set_output_array(out, 2); // Encode alpha and beta - compute_encoder->setBytes(&alpha_, sizeof(float), 3); - compute_encoder->setBytes(&beta_, sizeof(float), 4); + compute_encoder.set_bytes(alpha_, 3); + compute_encoder.set_bytes(beta_, 4); // Encode shape, strides and ndim if needed if (!contiguous_kernel) { - compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5); - compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6); - compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7); - compute_encoder->setBytes(&ndim, sizeof(int), 8); + compute_encoder.set_vector_bytes(x.shape(), 5); + compute_encoder.set_vector_bytes(x.strides(), 6); + compute_encoder.set_bytes(y.strides(), 7); + compute_encoder.set_bytes(ndim, 8); } // We launch 1 thread for each input and make sure that the number of @@ -295,7 +295,7 @@ void Axpby::eval_gpu( // Launch the grid with the given number of threads divided among // the given threadgroups - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } #else // Metal is not available diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 0453aaf01..cfbc82943 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -242,6 +242,9 @@ void MetalAllocator::clear_cache() { void MetalAllocator::free(Buffer buffer) { auto buf = static_cast(buffer.ptr()); + if (buf == nullptr) { + return; + } std::unique_lock lk(mutex_); residency_set_.erase(buf); active_memory_ -= buf->length(); diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index f70595e56..66a31c922 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -92,7 +92,7 @@ void binary_op_gpu_inplace( ? 
get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op) : get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op); auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // - If a is donated it goes to the first output // - If b is donated it goes to the first output if a was not donated @@ -117,19 +117,15 @@ void binary_op_gpu_inplace( size_t rest = out.size() / (dim0 * dim1); if (ndim > 3) { - compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++); - compute_encoder->setBytes( - strides_a.data(), ndim * sizeof(size_t), arg_idx++); - compute_encoder->setBytes( - strides_b.data(), ndim * sizeof(size_t), arg_idx++); - compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++); + compute_encoder.set_vector_bytes(shape, arg_idx++); + compute_encoder.set_vector_bytes(strides_a, arg_idx++); + compute_encoder.set_vector_bytes(strides_b, arg_idx++); + compute_encoder.set_bytes(ndim, arg_idx++); dim0 = (dim0 + work_per_thread - 1) / work_per_thread; } else { // The shape is implicit in the grid for <= 3D - compute_encoder->setBytes( - strides_a.data(), ndim * sizeof(size_t), arg_idx++); - compute_encoder->setBytes( - strides_b.data(), ndim * sizeof(size_t), arg_idx++); + compute_encoder.set_vector_bytes(strides_a, arg_idx++); + compute_encoder.set_vector_bytes(strides_b, arg_idx++); } if (thread_group_size != 1024) { @@ -137,7 +133,7 @@ void binary_op_gpu_inplace( } auto group_dims = get_block_dims(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } else { // Launch a 1D or 2D grid of threads size_t nthreads = out.data_size(); @@ -147,7 +143,7 @@ void binary_op_gpu_inplace( MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) : MTL::Size(nthreads, 1, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index af9b8c872..a5a8805fb 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -373,7 +373,7 @@ void Compiled::eval_gpu( } auto kernel = d.get_kernel(kernel_name, lib); auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Put the inputs in int cnt = 0; @@ -394,8 +394,7 @@ void Compiled::eval_gpu( } } if (!in_strides.empty()) { - compute_encoder->setBytes( - in_strides.data(), in_strides.size() * sizeof(size_t), cnt++); + compute_encoder.set_vector_bytes(in_strides, cnt++); } compiled_allocate_outputs( @@ -408,14 +407,13 @@ void Compiled::eval_gpu( // Put the output shape and strides in if (!contiguous) { - compute_encoder->setBytes( - strides[0].data(), strides[0].size() * sizeof(size_t), cnt++); - compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++); + compute_encoder.set_vector_bytes(strides[0], cnt++); + compute_encoder.set_vector_bytes(shape, cnt++); } // Put the number of dims in if it is dynamic if (dynamic) { - compute_encoder->setBytes(&ndim, sizeof(int), cnt++); + compute_encoder.set_bytes(ndim, cnt++); } // Launch the kernel @@ -427,7 +425,7 @@ void Compiled::eval_gpu( MTL::Size grid_dims = use_2d ? 
get_2d_grid_dims(outputs[0].shape(), outputs[0].strides()) : MTL::Size(nthreads, 1, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } else { size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; @@ -445,7 +443,7 @@ void Compiled::eval_gpu( } auto group_dims = get_block_dims(dim0, dim1, rest, pow2); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 0a2c62074..d5e715e80 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -44,12 +44,12 @@ void explicit_gemm_conv_ND_gpu( kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N; auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(in_unfolded, 1); - compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2); + compute_encoder.set_bytes(conv_params, 2); // Launch unfolding kernel int tgp_x = std::min(conv_params.C, 64); @@ -60,7 +60,7 @@ void explicit_gemm_conv_ND_gpu( MTL::Size grid_dims = MTL::Size( conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); // Reshape weight std::vector wt_reshape{implicit_K, implicit_N}; @@ -122,12 +122,12 @@ void explicit_gemm_conv_group_ND_gpu( << N; auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(in_unfolded, 1); - compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2); + compute_encoder.set_bytes(conv_params, 2); // Launch unfolding kernel int tgp_x = std::min(conv_params.C, 64); @@ -138,7 +138,7 @@ void explicit_gemm_conv_group_ND_gpu( MTL::Size grid_dims = MTL::Size( conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); // Transpose kernel weights so that we can slice them by contiguous chunks // of channel groups. 
@@ -237,7 +237,7 @@ void slow_conv_2D_gpu( // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); size_t n_pixels = conv_params.oS[0] * conv_params.oS[1]; @@ -252,8 +252,8 @@ void slow_conv_2D_gpu( compute_encoder.set_input_array(wt, 1); compute_encoder.set_output_array(out, 2); - compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.set_bytes(conv_params, 3); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void implicit_gemm_conv_2D_gpu( @@ -352,7 +352,7 @@ void implicit_gemm_conv_2D_gpu( wn, n_channel_specialization, small_filter); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Deduce grid launch dimensions int tile = 1 << swizzle_log; @@ -368,11 +368,11 @@ void implicit_gemm_conv_2D_gpu( compute_encoder.set_output_array(out, 2); // Encode params - compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3); - compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4); + compute_encoder.set_bytes(conv_params, 3); + compute_encoder.set_bytes(gemm_params, 4); // Launch kernel - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void implicit_gemm_conv_2D_general_gpu( @@ -506,7 +506,7 @@ void implicit_gemm_conv_2D_general_gpu( auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Deduce grid launch dimensions int tile = 1 << swizzle_log; @@ -523,17 +523,15 @@ void implicit_gemm_conv_2D_general_gpu( compute_encoder.set_output_array(out, 2); // Encode params - compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3); - compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4); - compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5); + compute_encoder.set_bytes(conv_params, 3); + compute_encoder.set_bytes(gemm_params, 4); + compute_encoder.set_bytes(jump_params, 5); - compute_encoder->setBytes( - base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6); - compute_encoder->setBytes( - base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7); + compute_encoder.set_vector_bytes(base_h, 6); + compute_encoder.set_vector_bytes(base_w, 7); // Launch kernel - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void winograd_conv_2D_gpu( @@ -622,18 +620,18 @@ void winograd_conv_2D_gpu( << bc; auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(wt, 0); compute_encoder.set_output_array(filt_wg, 1); - compute_encoder->setBytes(&C_c, sizeof(int), 2); - compute_encoder->setBytes(&O_c, sizeof(int), 3); + compute_encoder.set_bytes(C_c, 2); + compute_encoder.set_bytes(O_c, 3); MTL::Size group_dims = MTL::Size(32, bo, 1); MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + 
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } // Do input transform @@ -650,18 +648,17 @@ void winograd_conv_2D_gpu( << bc; auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in_padded, 0); compute_encoder.set_output_array(inp_wg, 1); - compute_encoder->setBytes( - &conv_params_updated, sizeof(MLXConvParams<2>), 2); + compute_encoder.set_bytes(conv_params_updated, 2); MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } // Do batched gemm @@ -698,18 +695,17 @@ void winograd_conv_2D_gpu( << bc; auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(out_wg, 0); compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes( - &conv_params_updated, sizeof(MLXConvParams<2>), 2); + compute_encoder.set_bytes(conv_params_updated, 2); MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 897eadb1c..f2f31cd1f 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -111,7 +111,7 @@ void copy_gpu_inplace( auto kernel = get_copy_kernel(d, kernel_name, in, out); auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); bool donate_in = in.data_shared_ptr() == nullptr; inp_offset *= size_of(in.dtype()); @@ -125,11 +125,11 @@ void copy_gpu_inplace( std::vector strides_in{strides_in_.begin(), strides_in_.end()}; std::vector strides_out{strides_out_.begin(), strides_out_.end()}; if (ndim > 3) { - set_vector_bytes(compute_encoder, shape, ndim, 2); + compute_encoder.set_vector_bytes(shape, ndim, 2); } - set_vector_bytes(compute_encoder, strides_in, ndim, 3); + compute_encoder.set_vector_bytes(strides_in, ndim, 3); if (ctype == CopyType::GeneralGeneral) { - set_vector_bytes(compute_encoder, strides_out, ndim, 4); + compute_encoder.set_vector_bytes(strides_out, ndim, 4); } int dim0 = ndim > 0 ? shape[ndim - 1] : 1; @@ -141,7 +141,7 @@ void copy_gpu_inplace( int rest = data_size / (dim0 * dim1); if (ndim > MAX_COPY_SPECIALIZED_DIMS) { - compute_encoder->setBytes(&ndim, sizeof(int), 5); + compute_encoder.set_bytes(ndim, 5); dim0 = (dim0 + work_per_thread - 1) / work_per_thread; } @@ -152,7 +152,7 @@ void copy_gpu_inplace( auto group_dims = get_block_dims(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } else { size_t nthreads = out.data_size(); if (thread_group_size > nthreads) { @@ -161,7 +161,7 @@ void copy_gpu_inplace( MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size grid_dims = use_2d ? 
get_2d_grid_dims(out.shape(), out.strides()) : MTL::Size(nthreads, 1, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } } @@ -199,7 +199,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) { type_to_name(val) + type_to_name(out); auto kernel = get_copy_kernel(d, kernel_name, val, out); auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(val, 0); compute_encoder.set_output_array(out, 1); @@ -212,7 +212,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) { MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) : MTL::Size(nthreads, 1, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } } // namespace mlx::core diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 06d7bf58c..8e0fb1173 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -43,7 +43,7 @@ void CustomKernel::eval_gpu( d.get_library(lib_name, [this] { return metal::utils() + source_; }); auto kernel = d.get_kernel(name_, lib); auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); int index = 0; for (int i = 0; i < checked_inputs.size(); i++) { const array& in = checked_inputs[i]; @@ -53,15 +53,15 @@ void CustomKernel::eval_gpu( if (in.ndim() > 0) { int ndim = in.ndim(); if (shape_info.shape) { - set_vector_bytes(compute_encoder, in.shape(), ndim, index); + compute_encoder.set_vector_bytes(in.shape(), ndim, index); index++; } if (shape_info.strides) { - set_vector_bytes(compute_encoder, in.strides(), ndim, index); + compute_encoder.set_vector_bytes(in.strides(), ndim, index); index++; } if (shape_info.ndim) { - compute_encoder->setBytes(&ndim, sizeof(int), index); + compute_encoder.set_bytes(ndim, index); index++; } } @@ -75,7 +75,7 @@ void CustomKernel::eval_gpu( MTL::Size group_dims = MTL::Size(tx, ty, tz); const auto [gx, gy, gz] = grid_; MTL::Size grid_dims = MTL::Size(gx, gy, gz); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); } diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index d7b758e4d..be3e0bc83 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -171,14 +171,14 @@ void CommandEncoder::maybeInsertBarrier() { next_outputs_.clear(); } -void CommandEncoder::dispatchThreadgroups( +void CommandEncoder::dispatch_threadgroups( MTL::Size grid_dims, MTL::Size group_dims) { maybeInsertBarrier(); enc_->dispatchThreadgroups(grid_dims, group_dims); } -void CommandEncoder::dispatchThreads( +void CommandEncoder::dispatch_threads( MTL::Size grid_dims, MTL::Size group_dims) { maybeInsertBarrier(); @@ -298,7 +298,7 @@ void Device::end_encoding(int index) { if (auto it = stream.outputs.find(in); it != stream.outputs.end()) { // If we've already waited on a fence, don't wait on it again. 
if (waiting_on.find(it->second) == waiting_on.end()) { - enc->waitForFence(it->second->fence); + enc.wait_for_fence(it->second->fence); waiting_on.insert(it->second); } } @@ -307,7 +307,7 @@ void Device::end_encoding(int index) { stream.outputs[out] = stream.fence; } } - enc->updateFence(stream.fence->fence); + enc.update_fence(stream.fence->fence); stream.buffer->addCompletedHandler( [&stream, waiting_on = std::move(waiting_on), diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index bd366dc47..09397dc36 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -58,16 +58,43 @@ struct CommandEncoder { CommandEncoder& enc; }; - MTL::ComputeCommandEncoder* operator->() { - return enc_; - } - void set_input_array(const array& a, int idx, int64_t offset = 0); void set_output_array(array& a, int idx, int64_t offset = 0); - void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims); - void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims); + void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims); + void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims); void maybeInsertBarrier(); + void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) { + enc_->setComputePipelineState(kernel); + } + + void wait_for_fence(MTL::Fence* fence) { + enc_->waitForFence(fence); + } + + void update_fence(MTL::Fence* fence) { + enc_->updateFence(fence); + } + + template + void set_vector_bytes(const std::vector& vec, size_t nelems, int idx) { + enc_->setBytes(vec.data(), nelems * sizeof(T), idx); + } + template + void set_vector_bytes(const std::vector& vec, int idx) { + return set_vector_bytes(vec, vec.size(), idx); + } + + template + void set_bytes(const T* v, int n, int idx) { + return enc_->setBytes(v, n * sizeof(T), idx); + } + + template + void set_bytes(const T& v, int idx) { + return enc_->setBytes(&v, sizeof(T), idx); + } + ConcurrentContext start_concurrent() { return ConcurrentContext(*this); } diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 12668eca7..e6da71fe1 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -699,7 +699,7 @@ void fft_op( auto kernel = get_fft_kernel(d, base_name, hash_name, func_consts, template_def); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in_contiguous, 0); compute_encoder.set_output_array(out, 1); @@ -711,9 +711,9 @@ void fft_op( compute_encoder.set_input_array(w_q, 2); // w_q compute_encoder.set_input_array(w_k, 3); // w_k - compute_encoder->setBytes(&n, sizeof(int), 4); - compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5); - compute_encoder->setBytes(&total_batch_size, sizeof(int), 6); + compute_encoder.set_bytes(n, 4); + compute_encoder.set_bytes(plan.bluestein_n, 5); + compute_encoder.set_bytes(total_batch_size, 6); } else if (plan.rader_n > 1) { auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s); copies.push_back(b_q); @@ -723,22 +723,22 @@ void fft_op( compute_encoder.set_input_array(b_q, 2); compute_encoder.set_input_array(g_q, 3); compute_encoder.set_input_array(g_minus_q, 4); - compute_encoder->setBytes(&n, sizeof(int), 5); - compute_encoder->setBytes(&total_batch_size, sizeof(int), 6); - compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7); + compute_encoder.set_bytes(n, 5); + compute_encoder.set_bytes(total_batch_size, 6); + compute_encoder.set_bytes(plan.rader_n, 7); } else if (four_step_params.required) { 
- compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2); - compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3); - compute_encoder->setBytes(&total_batch_size, sizeof(int), 4); + compute_encoder.set_bytes(four_step_params.n1, 2); + compute_encoder.set_bytes(four_step_params.n2, 3); + compute_encoder.set_bytes(total_batch_size, 4); } else { - compute_encoder->setBytes(&n, sizeof(int), 2); - compute_encoder->setBytes(&total_batch_size, sizeof(int), 3); + compute_encoder.set_bytes(n, 2); + compute_encoder.set_bytes(total_batch_size, 3); } auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft); auto grid_dims = MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } d.add_temporaries(std::move(copies), s.index); diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp index 83a17a7a3..7dcc761af 100644 --- a/mlx/backend/metal/hadamard.cpp +++ b/mlx/backend/metal/hadamard.cpp @@ -137,14 +137,14 @@ void Hadamard::eval_gpu(const std::vector& inputs, array& out) { auto kernel = d.get_kernel(kernel_name, lib); assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup()); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&scale, sizeof(float), 2); + compute_encoder.set_bytes(scale, 2); MTL::Size group_dims = MTL::Size(1, threads_per, 1); MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); }; if (m > 1) { diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 75a532346..b7cb0ab7a 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -87,7 +87,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kernel_name, lib); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); size_t slice_size = 1; for (auto s : slice_sizes_) { @@ -131,20 +131,20 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 1); // Set source info - set_vector_bytes(compute_encoder, src.shape(), 2); - set_vector_bytes(compute_encoder, src.strides(), 3); - compute_encoder->setBytes(&ndim, sizeof(size_t), 4); - set_vector_bytes(compute_encoder, slice_sizes_, 5); - set_vector_bytes(compute_encoder, axes_, 6); + compute_encoder.set_vector_bytes(src.shape(), 2); + compute_encoder.set_vector_bytes(src.strides(), 3); + compute_encoder.set_bytes(ndim, 4); + compute_encoder.set_vector_bytes(slice_sizes_, 5); + compute_encoder.set_vector_bytes(axes_, 6); // Set index info // // We don't need to check for empty idx_shapes because gather has a // idx_ndim == 0 specialization - set_vector_bytes(compute_encoder, idx_shapes, 7); - set_vector_bytes(compute_encoder, idx_strides, 8); - set_vector_bytes(compute_encoder, idx_contigs, 9); - compute_encoder->setBytes(&idx_ndim, sizeof(int), 10); + compute_encoder.set_vector_bytes(idx_shapes, 7); + compute_encoder.set_vector_bytes(idx_strides, 8); + compute_encoder.set_vector_bytes(idx_contigs, 9); + compute_encoder.set_bytes(idx_ndim, 10); // Set index buffers for (int i = 0; i < nidx; 
++i) { @@ -152,7 +152,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { } // Launch grid - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } void Scatter::eval_gpu(const std::vector& inputs, array& out) { @@ -289,7 +289,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { size_t nthreads = upd.size(); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Set all the buffers compute_encoder.set_input_array(upd, 1); @@ -323,14 +323,14 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { // Need placeholders so Metal doesn't compalain int shape_ = 0; size_t stride_ = 0; - compute_encoder->setBytes(&shape_, sizeof(int), 3); - compute_encoder->setBytes(&stride_, sizeof(size_t), 4); + compute_encoder.set_bytes(shape_, 3); + compute_encoder.set_bytes(stride_, 4); } else { - set_vector_bytes(compute_encoder, upd.shape(), 3); - set_vector_bytes(compute_encoder, upd.strides(), 4); + compute_encoder.set_vector_bytes(upd.shape(), 3); + compute_encoder.set_vector_bytes(upd.strides(), 4); } - compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5); - compute_encoder->setBytes(&upd_size, sizeof(size_t), 6); + compute_encoder.set_bytes(upd_ndim, 5); + compute_encoder.set_bytes(upd_size, 6); // Set output info size_t out_ndim = out.ndim(); @@ -338,14 +338,14 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { // Need placeholders so Metal doesn't compalain int shape_ = 0; size_t stride_ = 0; - compute_encoder->setBytes(&shape_, sizeof(int), 7); - compute_encoder->setBytes(&stride_, sizeof(size_t), 8); + compute_encoder.set_bytes(shape_, 7); + compute_encoder.set_bytes(stride_, 8); } else { - set_vector_bytes(compute_encoder, out.shape(), 7); - set_vector_bytes(compute_encoder, out.strides(), 8); + compute_encoder.set_vector_bytes(out.shape(), 7); + compute_encoder.set_vector_bytes(out.strides(), 8); } - compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9); - compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10); + compute_encoder.set_bytes(out_ndim, 9); + compute_encoder.set_vector_bytes(axes_, 10); // Set index info if (idx_ndim == 0) { @@ -355,11 +355,11 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { idx_strides.push_back(0); idx_contigs.push_back(false); } - set_vector_bytes(compute_encoder, idx_shapes, 11); - set_vector_bytes(compute_encoder, idx_strides, 12); - set_vector_bytes(compute_encoder, idx_contigs, 13); - compute_encoder->setBytes(&idx_ndim, sizeof(int), 14); - compute_encoder->setBytes(&idx_size, sizeof(size_t), 15); + compute_encoder.set_vector_bytes(idx_shapes, 11); + compute_encoder.set_vector_bytes(idx_strides, 12); + compute_encoder.set_vector_bytes(idx_contigs, 13); + compute_encoder.set_bytes(idx_ndim, 14); + compute_encoder.set_bytes(idx_size, 15); // Set index buffers for (int i = 0; i < nidx; ++i) { @@ -375,7 +375,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads"); } MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/rms_norm.metal b/mlx/backend/metal/kernels/rms_norm.metal index c79bbae10..7d89dd052 100644 --- a/mlx/backend/metal/kernels/rms_norm.metal +++ 
b/mlx/backend/metal/kernels/rms_norm.metal @@ -17,12 +17,15 @@ template constant float& eps, constant uint& axis_size, constant uint& w_stride, - threadgroup float* local_inv_mean [[threadgroup(0)]], - threadgroup float* local_sums [[threadgroup(1)]], uint gid [[threadgroup_position_in_grid]], uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int SIMD_SIZE = 32; + + threadgroup float local_inv_mean[1]; + threadgroup float local_sums[SIMD_SIZE]; + float acc = 0; x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; @@ -84,13 +87,15 @@ template constant float& eps, constant uint& axis_size, constant uint& w_stride, - threadgroup float* local_inv_mean [[threadgroup(0)]], - threadgroup float* local_sums [[threadgroup(1)]], uint gid [[threadgroup_position_in_grid]], uint lid [[thread_position_in_threadgroup]], uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int SIMD_SIZE = 32; + threadgroup float local_inv_mean[1]; + threadgroup float local_sums[SIMD_SIZE]; + float acc = 0; x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; @@ -376,8 +381,6 @@ template constant float& eps, \ constant uint& axis_size, \ constant uint& w_stride, \ - threadgroup float* local_inv_mean [[threadgroup(0)]], \ - threadgroup float* local_sums [[threadgroup(1)]], \ uint gid [[thread_position_in_grid]], \ uint lid [[thread_position_in_threadgroup]], \ uint simd_lane_id [[thread_index_in_simdgroup]], \ @@ -407,8 +410,6 @@ template constant float& eps, \ constant uint& axis_size, \ constant uint& w_stride, \ - threadgroup float* local_inv_mean [[threadgroup(0)]], \ - threadgroup float* local_sums [[threadgroup(1)]], \ uint gid [[thread_position_in_grid]], \ uint lid [[thread_position_in_threadgroup]], \ uint lsize [[threads_per_threadgroup]], \ diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 0614fadc7..af3e85ec8 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -249,7 +249,7 @@ void steel_matmul_regular( wm, wn); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Use problem size to determine threadblock swizzle int tn = (N + bn - 1) / bn; @@ -288,12 +288,12 @@ void steel_matmul_regular( compute_encoder.set_input_array(b, 1); compute_encoder.set_output_array(out, 3); - compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4); + compute_encoder.set_bytes(params, 4); - set_vector_bytes(compute_encoder, batch_shape, 6); - set_vector_bytes(compute_encoder, batch_strides, 7); + compute_encoder.set_vector_bytes(batch_shape, 6); + compute_encoder.set_vector_bytes(batch_strides, 7); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Record copies d.add_temporaries(std::move(copies), s.index); @@ -390,7 +390,7 @@ void steel_matmul( wn, mn_aligned, k_aligned); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); int tn = (N + bn - 1) / bn; int tm = (M + bm - 1) / bm; @@ -416,34 +416,30 @@ void steel_matmul( compute_encoder.set_input_array(b, 1); compute_encoder.set_output_array(C_split, 2); - compute_encoder->setBytes(¶ms, sizeof(GEMMSpiltKParams), 3); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + 
compute_encoder.set_bytes(params, 3); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Do accum kernel { - auto c_split_buf = - static_cast(C_split.buffer().ptr()); - const class MTL::Resource* const resources[1] = {c_split_buf}; - compute_encoder->memoryBarrier(resources, 1); auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + type_to_name(C_split); auto kernel = get_steel_gemm_splitk_accum_kernel( d, kernel_name, C_split, out, false); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Set the arguments for the kernel compute_encoder.set_input_array(C_split, 0); compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2); - compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3); - compute_encoder->setBytes(&N, sizeof(int), 4); + compute_encoder.set_bytes(split_k_partitions, 2); + compute_encoder.set_bytes(split_k_partition_stride, 3); + compute_encoder.set_bytes(N, 4); // Launch enough thread groups for each output MTL::Size grid_dims = MTL::Size(N, M, 1); MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } d.add_temporaries(std::move(copies), s.index); @@ -625,7 +621,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; MTL::Size group_dims = MTL::Size(32, bn, bm); @@ -635,16 +631,16 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_input_array(vec, 1); compute_encoder.set_output_array(out, 3); - compute_encoder->setBytes(&in_vector_len, sizeof(int), 4); - compute_encoder->setBytes(&out_vector_len, sizeof(int), 5); - compute_encoder->setBytes(&mat_ld, sizeof(int), 6); + compute_encoder.set_bytes(in_vector_len, 4); + compute_encoder.set_bytes(out_vector_len, 5); + compute_encoder.set_bytes(mat_ld, 6); - compute_encoder->setBytes(&batch_ndim, sizeof(int), 9); - set_vector_bytes(compute_encoder, batch_shape, 10); - set_vector_bytes(compute_encoder, batch_strides_vec, 11); - set_vector_bytes(compute_encoder, batch_strides_mat, 12); + compute_encoder.set_bytes(batch_ndim, 9); + compute_encoder.set_vector_bytes(batch_shape, 10); + compute_encoder.set_vector_bytes(batch_strides_vec, 11); + compute_encoder.set_vector_bytes(batch_strides_mat, 12); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); return; @@ -822,7 +818,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; MTL::Size group_dims = MTL::Size(32, bn, bm); @@ -833,23 +829,23 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_input_array(c, 2); compute_encoder.set_output_array(out, 3); - compute_encoder->setBytes(&in_vector_len, sizeof(int), 4); - 
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5); - compute_encoder->setBytes(&mat_ld, sizeof(int), 6); + compute_encoder.set_bytes(in_vector_len, 4); + compute_encoder.set_bytes(out_vector_len, 5); + compute_encoder.set_bytes(mat_ld, 6); - compute_encoder->setBytes(&alpha_, sizeof(float), 7); - compute_encoder->setBytes(&beta_, sizeof(float), 8); + compute_encoder.set_bytes(alpha_, 7); + compute_encoder.set_bytes(beta_, 8); - compute_encoder->setBytes(&batch_ndim, sizeof(int), 9); - set_vector_bytes(compute_encoder, batch_shape, 10); - set_vector_bytes(compute_encoder, batch_strides_vec, 11); - set_vector_bytes(compute_encoder, batch_strides_mat, 12); - set_vector_bytes(compute_encoder, C_batch_stride, 13); + compute_encoder.set_bytes(batch_ndim, 9); + compute_encoder.set_vector_bytes(batch_shape, 10); + compute_encoder.set_vector_bytes(batch_strides_vec, 11); + compute_encoder.set_vector_bytes(batch_strides_mat, 12); + compute_encoder.set_vector_bytes(C_batch_stride, 13); int bias_stride = c.strides()[c.ndim() - 1]; - compute_encoder->setBytes(&bias_stride, sizeof(int), 14); + compute_encoder.set_bytes(bias_stride, 14); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); return; @@ -907,7 +903,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { mn_aligned, k_aligned); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); int tn = (N + bn - 1) / bn; int tm = (M + bm - 1) / bm; @@ -933,8 +929,8 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_input_array(b, 1); compute_encoder.set_output_array(C_split, 2); - compute_encoder->setBytes(¶ms, sizeof(GEMMSpiltKParams), 3); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.set_bytes(params, 3); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Do accum kernel { @@ -943,25 +939,25 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { auto kernel = get_steel_gemm_splitk_accum_kernel( d, kernel_name, C_split, out, true); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Set the arguments for the kernel compute_encoder.set_input_array(C_split, 0); compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2); - compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3); - compute_encoder->setBytes(&N, sizeof(int), 4); + compute_encoder.set_bytes(split_k_partitions, 2); + compute_encoder.set_bytes(split_k_partition_stride, 3); + compute_encoder.set_bytes(N, 4); compute_encoder.set_input_array(c, 5); - compute_encoder->setBytes(&ldc, sizeof(int), 6); - compute_encoder->setBytes(&fdc, sizeof(int), 7); - compute_encoder->setBytes(&alpha_, sizeof(float), 8); - compute_encoder->setBytes(&beta_, sizeof(float), 9); + compute_encoder.set_bytes(ldc, 6); + compute_encoder.set_bytes(fdc, 7); + compute_encoder.set_bytes(alpha_, 8); + compute_encoder.set_bytes(beta_, 9); // Launch enough thread groups for each output MTL::Size grid_dims = MTL::Size(N, M, 1); MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } d.add_temporaries(std::move(copies), s.index); @@ -1032,7 +1028,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) 
{ wm, wn); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); int tn = (N + bn - 1) / bn; int tm = (M + bm - 1) / bm; @@ -1083,13 +1079,13 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_input_array(c, 2); compute_encoder.set_output_array(out, 3); - compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4); - compute_encoder->setBytes(¶ms, sizeof(GEMMAddMMParams), 5); + compute_encoder.set_bytes(gemm_params, 4); + compute_encoder.set_bytes(params, 5); - set_vector_bytes(compute_encoder, batch_shape, 6); - set_vector_bytes(compute_encoder, batch_strides, 7); + compute_encoder.set_vector_bytes(batch_shape, 6); + compute_encoder.set_vector_bytes(batch_strides, 7); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); } @@ -1304,7 +1300,7 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { contiguous_kernel); auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; MTL::Size group_dims = MTL::Size(32, bn, bm); @@ -1372,18 +1368,18 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_input_array(vec, 1); compute_encoder.set_output_array(out, 3); - compute_encoder->setBytes(&in_vector_len, sizeof(int), 4); - compute_encoder->setBytes(&out_vector_len, sizeof(int), 5); - compute_encoder->setBytes(&mat_ld, sizeof(int), 6); - compute_encoder->setBytes(&batch_ndim, sizeof(int), 9); - set_vector_bytes(compute_encoder, batch_shape, 10); - set_vector_bytes(compute_encoder, batch_strides_vec, 11); - set_vector_bytes(compute_encoder, batch_strides_mat, 12); + compute_encoder.set_bytes(in_vector_len, 4); + compute_encoder.set_bytes(out_vector_len, 5); + compute_encoder.set_bytes(mat_ld, 6); + compute_encoder.set_bytes(batch_ndim, 9); + compute_encoder.set_vector_bytes(batch_shape, 10); + compute_encoder.set_vector_bytes(batch_strides_vec, 11); + compute_encoder.set_vector_bytes(batch_strides_mat, 12); - set_vector_bytes(compute_encoder, mask_strides, 23); - set_vector_bytes(compute_encoder, mask_batch_strides, 24); + compute_encoder.set_vector_bytes(mask_strides, 23); + compute_encoder.set_vector_bytes(mask_batch_strides, 24); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); return; @@ -1423,7 +1419,7 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { wn, mn_aligned, k_aligned); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Use problem size to determine threadblock swizzle int tn = (N + bn - 1) / bn; @@ -1486,14 +1482,14 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_input_array(b, 1); compute_encoder.set_output_array(out, 3); - compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4); + compute_encoder.set_bytes(params, 4); - set_vector_bytes(compute_encoder, batch_shape, 6); - set_vector_bytes(compute_encoder, batch_strides, 7); + compute_encoder.set_vector_bytes(batch_shape, 6); + compute_encoder.set_vector_bytes(batch_strides, 7); - set_vector_bytes(compute_encoder, mask_strides, 13); + 
compute_encoder.set_vector_bytes(mask_strides, 13); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); } @@ -1687,7 +1683,7 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; MTL::Size group_dims = MTL::Size(32, bn, bm); @@ -1697,28 +1693,28 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_input_array(vec, 1); compute_encoder.set_output_array(out, 3); - compute_encoder->setBytes(&in_vector_len, sizeof(int), 4); - compute_encoder->setBytes(&out_vector_len, sizeof(int), 5); - compute_encoder->setBytes(&mat_ld, sizeof(int), 6); + compute_encoder.set_bytes(in_vector_len, 4); + compute_encoder.set_bytes(out_vector_len, 5); + compute_encoder.set_bytes(mat_ld, 6); - compute_encoder->setBytes(&batch_ndim, sizeof(int), 9); - set_vector_bytes(compute_encoder, batch_shape, 10); - set_vector_bytes(compute_encoder, batch_strides, 11); + compute_encoder.set_bytes(batch_ndim, 9); + compute_encoder.set_vector_bytes(batch_shape, 10); + compute_encoder.set_vector_bytes(batch_strides, 11); int batch_ndim_vec = batch_shape_vec.size(); - compute_encoder->setBytes(&batch_ndim_vec, sizeof(int), 12); - set_vector_bytes(compute_encoder, batch_shape_vec, 13); - set_vector_bytes(compute_encoder, batch_strides_vec, 14); + compute_encoder.set_bytes(batch_ndim_vec, 12); + compute_encoder.set_vector_bytes(batch_shape_vec, 13); + compute_encoder.set_vector_bytes(batch_strides_vec, 14); int batch_ndim_mat = batch_shape_mat.size(); - compute_encoder->setBytes(&batch_ndim_mat, sizeof(int), 15); - set_vector_bytes(compute_encoder, batch_shape_mat, 16); - set_vector_bytes(compute_encoder, batch_strides_mat, 17); + compute_encoder.set_bytes(batch_ndim_mat, 15); + compute_encoder.set_vector_bytes(batch_shape_mat, 16); + compute_encoder.set_vector_bytes(batch_strides_mat, 17); compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix)); compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix)); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); return; @@ -1788,7 +1784,7 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { wm, wn); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Use problem size to determine threadblock swizzle int tn = (N + bn - 1) / bn; @@ -1827,10 +1823,10 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_input_array(b, 1); compute_encoder.set_output_array(out, 3); - compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4); + compute_encoder.set_bytes(params, 4); - set_vector_bytes(compute_encoder, batch_shape, 6); - set_vector_bytes(compute_encoder, batch_strides, 7); + compute_encoder.set_vector_bytes(batch_shape, 6); + compute_encoder.set_vector_bytes(batch_strides, 7); compute_encoder.set_input_array(lhs_indices, 10); compute_encoder.set_input_array(rhs_indices, 11); @@ -1845,11 +1841,11 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { operand_batch_ndim.push_back(0); - 
set_vector_bytes(compute_encoder, operand_shape, 13); - set_vector_bytes(compute_encoder, operand_strides, 14); - set_vector_bytes(compute_encoder, operand_batch_ndim, 15); + compute_encoder.set_vector_bytes(operand_shape, 13); + compute_encoder.set_vector_bytes(operand_strides, 14); + compute_encoder.set_vector_bytes(operand_batch_ndim, 15); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); } diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index cdab18368..7fa7e8646 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -78,18 +78,15 @@ void RMSNorm::eval_gpu( } uint32_t w_stride = w.strides()[0]; - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array( x.data_shared_ptr() == nullptr ? out : x, 0); compute_encoder.set_input_array(w, 1); compute_encoder.set_output_array(out, 2); - compute_encoder->setBytes(&eps_, sizeof(float), 3); - compute_encoder->setBytes(&axis_size, sizeof(int), 4); - compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 5); - compute_encoder->setThreadgroupMemoryLength( - 16 * 8, 0); // minimum of 16 bytes - compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.set_bytes(eps_, 3); + compute_encoder.set_bytes(axis_size, 4); + compute_encoder.set_bytes(w_stride, 5); + compute_encoder.dispatch_threads(grid_dims, group_dims); } d.add_temporaries(std::move(copies), s.index); @@ -183,16 +180,16 @@ void RMSNormVJP::eval_gpu( } uint32_t w_stride = w.strides()[0]; - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(x_in_gx ? gx : x, 0); compute_encoder.set_input_array(w, 1); compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2); compute_encoder.set_output_array(gx, 3); compute_encoder.set_output_array(gw_temp, 4); - compute_encoder->setBytes(&eps_, sizeof(float), 5); - compute_encoder->setBytes(&axis_size, sizeof(int), 6); - compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.set_bytes(eps_, 5); + compute_encoder.set_bytes(axis_size, 6); + compute_encoder.set_bytes(w_stride, 7); + compute_encoder.dispatch_threads(grid_dims, group_dims); } ReductionPlan plan( @@ -273,17 +270,17 @@ void LayerNorm::eval_gpu( uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0; - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array( x.data_shared_ptr() == nullptr ? 
out : x, 0); compute_encoder.set_input_array(w, 1); compute_encoder.set_input_array(b, 2); compute_encoder.set_output_array(out, 3); - compute_encoder->setBytes(&eps_, sizeof(float), 4); - compute_encoder->setBytes(&axis_size, sizeof(int), 5); - compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6); - compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.set_bytes(eps_, 4); + compute_encoder.set_bytes(axis_size, 5); + compute_encoder.set_bytes(w_stride, 6); + compute_encoder.set_bytes(b_stride, 7); + compute_encoder.dispatch_threads(grid_dims, group_dims); } d.add_temporaries(std::move(copies), s.index); @@ -395,16 +392,16 @@ void LayerNormVJP::eval_gpu( } uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(x_in_gx ? gx : x, 0); compute_encoder.set_input_array(w, 1); compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2); compute_encoder.set_output_array(gx, 3); compute_encoder.set_output_array(gw_temp, 4); - compute_encoder->setBytes(&eps_, sizeof(float), 5); - compute_encoder->setBytes(&axis_size, sizeof(int), 6); - compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.set_bytes(eps_, 5); + compute_encoder.set_bytes(axis_size, 6); + compute_encoder.set_bytes(w_stride, 7); + compute_encoder.dispatch_threads(grid_dims, group_dims); } if (gw.ndim() == 1 && gw.size() == axis_size) { diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 67ec2949a..da176ffd1 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -17,10 +17,10 @@ namespace mlx::core { template -void arange_set_scalars(T start, T next, CommandEncoder& enc) { - enc->setBytes(&start, sizeof(T), 0); +void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) { + enc.set_bytes(start, 0); T step = next - start; - enc->setBytes(&step, sizeof(T), 1); + enc.set_bytes(step, 1); } void Arange::eval_gpu(const std::vector& inputs, array& out) { @@ -37,7 +37,7 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { MTL::Size group_dims = MTL::Size( std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); switch (out.dtype()) { case bool_: // unsupported @@ -80,7 +80,7 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { } compute_encoder.set_output_array(out, 2); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { @@ -129,25 +129,25 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { size_t n_threads = out.size() * thread_group_size; MTL::Size grid_dims = MTL::Size(n_threads, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); if (ndim == 0) { // Pass place holders so metal doesn't complain int shape_ = 0; size_t stride_ = 0; - compute_encoder->setBytes(&shape_, sizeof(int), 2); - 
compute_encoder->setBytes(&stride_, sizeof(size_t), 3); - compute_encoder->setBytes(&stride_, sizeof(size_t), 4); + compute_encoder.set_bytes(shape_, 2); + compute_encoder.set_bytes(stride_, 3); + compute_encoder.set_bytes(stride_, 4); } else { - compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2); - compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3); - compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4); + compute_encoder.set_vector_bytes(shape, 2); + compute_encoder.set_vector_bytes(in_strides, 3); + compute_encoder.set_vector_bytes(out_strides, 4); } - compute_encoder->setBytes(&ndim, sizeof(size_t), 5); - compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6); - compute_encoder->setBytes(&axis_size, sizeof(size_t), 7); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.set_bytes(ndim, 5); + compute_encoder.set_bytes(axis_stride, 6); + compute_encoder.set_bytes(axis_size, 7); + compute_encoder.dispatch_threads(grid_dims, group_dims); } } @@ -275,22 +275,20 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); MTL::Size group_dims = MTL::Size(1, thread_group_size, 1); auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(keys, 0); compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&odd, sizeof(bool), 2); - compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3); + compute_encoder.set_bytes(odd, 2); + compute_encoder.set_bytes(bytes_per_key, 3); if (!keys.flags().row_contiguous) { int ndim = keys.ndim(); - compute_encoder->setBytes(&ndim, sizeof(int), 4); - compute_encoder->setBytes( - keys.shape().data(), keys.ndim() * sizeof(int), 5); - compute_encoder->setBytes( - keys.strides().data(), keys.ndim() * sizeof(size_t), 6); + compute_encoder.set_bytes(ndim, 4); + compute_encoder.set_vector_bytes(keys.shape(), 5); + compute_encoder.set_vector_bytes(keys.strides(), 6); } - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } void Reshape::eval_gpu(const std::vector& inputs, array& out) { diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 4a74f2925..d41502617 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -101,31 +101,31 @@ void launch_qmm( auto& d = metal::device(s.device); auto kernel = get_quantized_kernel(d, kname.str(), template_def); auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(w, 0); compute_encoder.set_input_array(scales, 1); compute_encoder.set_input_array(biases, 2); compute_encoder.set_input_array(x, 3); compute_encoder.set_output_array(out, 4); - compute_encoder->setBytes(&D, sizeof(int), 5); - compute_encoder->setBytes(&O, sizeof(int), 6); + compute_encoder.set_bytes(D, 5); + compute_encoder.set_bytes(O, 6); int offset = 7; if (matrix) { - compute_encoder->setBytes(&B, sizeof(int), 7); + compute_encoder.set_bytes(B, 7); offset += 1; } if (batched || gather) { - compute_encoder->setBytes(&x_batch_ndims, sizeof(int), offset); - set_vector_bytes(compute_encoder, x_shape, offset + 1); - set_vector_bytes(compute_encoder, x_strides, offset + 2); - 
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), offset + 3); - set_vector_bytes(compute_encoder, w_shape, offset + 4); - set_vector_bytes(compute_encoder, w_strides, offset + 5); - set_vector_bytes(compute_encoder, s_strides, offset + 6); - set_vector_bytes(compute_encoder, b_strides, offset + 7); + compute_encoder.set_bytes(x_batch_ndims, offset); + compute_encoder.set_vector_bytes(x_shape, offset + 1); + compute_encoder.set_vector_bytes(x_strides, offset + 2); + compute_encoder.set_bytes(w_batch_ndims, offset + 3); + compute_encoder.set_vector_bytes(w_shape, offset + 4); + compute_encoder.set_vector_bytes(w_strides, offset + 5); + compute_encoder.set_vector_bytes(s_strides, offset + 6); + compute_encoder.set_vector_bytes(b_strides, offset + 7); } if (gather) { auto& lhs_indices = inputs[4]; @@ -137,15 +137,15 @@ void launch_qmm( auto& lhs_strides = lhs_indices.strides(); auto& rhs_strides = rhs_indices.strides(); - compute_encoder->setBytes(&batch_ndims, sizeof(int), offset + 8); - set_vector_bytes(compute_encoder, batch_shape, offset + 9); + compute_encoder.set_bytes(batch_ndims, offset + 8); + compute_encoder.set_vector_bytes(batch_shape, offset + 9); compute_encoder.set_input_array(lhs_indices, offset + 10); compute_encoder.set_input_array(rhs_indices, offset + 11); - set_vector_bytes(compute_encoder, lhs_strides, offset + 12); - set_vector_bytes(compute_encoder, rhs_strides, offset + 13); + compute_encoder.set_vector_bytes(lhs_strides, offset + 12); + compute_encoder.set_vector_bytes(rhs_strides, offset + 13); } - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); } @@ -236,27 +236,27 @@ void qvm_split_k( // Encode and dispatch kernel auto kernel = get_quantized_kernel(d, kname.str(), template_def); auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(w, 0); compute_encoder.set_input_array(scales, 1); compute_encoder.set_input_array(biases, 2); compute_encoder.set_input_array(x, 3); compute_encoder.set_output_array(intermediate, 4); - compute_encoder->setBytes(&split_D, sizeof(int), 5); - compute_encoder->setBytes(&O, sizeof(int), 6); + compute_encoder.set_bytes(split_D, 5); + compute_encoder.set_bytes(O, 6); - compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 7); - set_vector_bytes(compute_encoder, x_shape, 8); - set_vector_bytes(compute_encoder, x_strides, 9); - compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 10); - set_vector_bytes(compute_encoder, w_shape, 11); - set_vector_bytes(compute_encoder, w_strides, 12); - set_vector_bytes(compute_encoder, s_strides, 13); - set_vector_bytes(compute_encoder, b_strides, 14); - compute_encoder->setBytes(&final_block_size, sizeof(int), 15); + compute_encoder.set_bytes(x_batch_ndims, 7); + compute_encoder.set_vector_bytes(x_shape, 8); + compute_encoder.set_vector_bytes(x_strides, 9); + compute_encoder.set_bytes(w_batch_ndims, 10); + compute_encoder.set_vector_bytes(w_shape, 11); + compute_encoder.set_vector_bytes(w_strides, 12); + compute_encoder.set_vector_bytes(s_strides, 13); + compute_encoder.set_vector_bytes(b_strides, 14); + compute_encoder.set_bytes(final_block_size, 15); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); int axis = 
intermediate.ndim() - 3; @@ -447,7 +447,7 @@ void fast::AffineQuantize::eval_gpu( auto template_def = get_template_definition( kname.str(), kernel_func, type_string, group_size_, bits_); auto kernel = get_quantized_kernel(d, kname.str(), template_def); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Treat uint32 as uint8 in kernel constexpr int uint8_per_uint32 = 4; @@ -471,7 +471,7 @@ void fast::AffineQuantize::eval_gpu( } MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides()) : MTL::Size(nthreads, 1, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); d.add_temporaries(std::move(copies), s.index); } diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 6adab0824..960b1898e 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -67,17 +67,14 @@ struct RowReduceArgs { strides.push_back(0); } - compute_encoder->setBytes(&row_size, sizeof(size_t), 2); - compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 3); - compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); - compute_encoder->setBytes( - strides.data(), strides.size() * sizeof(size_t), 5); - compute_encoder->setBytes(&ndim, sizeof(int), 6); - compute_encoder->setBytes( - reduce_shape.data(), reduce_shape.size() * sizeof(int), 7); - compute_encoder->setBytes( - reduce_strides.data(), reduce_strides.size() * sizeof(size_t), 8); - compute_encoder->setBytes(&reduce_ndim, sizeof(int), 9); + compute_encoder.set_bytes(row_size, 2); + compute_encoder.set_bytes(non_row_reductions, 3); + compute_encoder.set_vector_bytes(shape, 4); + compute_encoder.set_vector_bytes(strides, 5); + compute_encoder.set_bytes(ndim, 6); + compute_encoder.set_vector_bytes(reduce_shape, 7); + compute_encoder.set_vector_bytes(reduce_strides, 8); + compute_encoder.set_bytes(reduce_ndim, 9); if (reduce_ndim == 0) { reduce_shape.pop_back(); @@ -166,18 +163,15 @@ struct ColReduceArgs { strides.push_back(0); } - compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); - compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); - compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); - compute_encoder->setBytes( - strides.data(), strides.size() * sizeof(size_t), 5); - compute_encoder->setBytes(&ndim, sizeof(int), 6); - compute_encoder->setBytes( - reduce_shape.data(), reduce_shape.size() * sizeof(int), 7); - compute_encoder->setBytes( - reduce_strides.data(), reduce_strides.size() * sizeof(size_t), 8); - compute_encoder->setBytes(&reduce_ndim, sizeof(int), 9); - compute_encoder->setBytes(&non_col_reductions, sizeof(size_t), 10); + compute_encoder.set_bytes(reduction_size, 2); + compute_encoder.set_bytes(reduction_stride, 3); + compute_encoder.set_vector_bytes(shape, 4); + compute_encoder.set_vector_bytes(strides, 5); + compute_encoder.set_bytes(ndim, 6); + compute_encoder.set_vector_bytes(reduce_shape, 7); + compute_encoder.set_vector_bytes(reduce_strides, 8); + compute_encoder.set_bytes(reduce_ndim, 9); + compute_encoder.set_bytes(non_col_reductions, 10); if (reduce_ndim == 0) { reduce_shape.pop_back(); @@ -256,9 +250,9 @@ void init_reduce( thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_output_array(out, 0); - compute_encoder.dispatchThreads(grid_dims, 
group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } void all_reduce_dispatch( @@ -273,7 +267,7 @@ void all_reduce_dispatch( const std::string func_name = "all_reduce"; kname << func_name << "_" << op_name << type_to_name(in); auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); size_t in_size = in.size(); @@ -285,9 +279,9 @@ void all_reduce_dispatch( compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&in_size, sizeof(size_t), 2); - compute_encoder->setBytes(&in_size, sizeof(size_t), 3); - compute_encoder.dispatchThreads(grid_dims, grid_dims); + compute_encoder.set_bytes(in_size, 2); + compute_encoder.set_bytes(in_size, 3); + compute_encoder.dispatch_threads(grid_dims, grid_dims); } // We need multiple threadgroups so we 'll do it in 2 passes. @@ -319,24 +313,24 @@ void all_reduce_dispatch( MTL::Size group_dims(threadgroup_size, 1, 1); compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(intermediate, 1); - compute_encoder->setBytes(&in_size, sizeof(size_t), 2); - compute_encoder->setBytes(&row_size, sizeof(size_t), 3); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.set_bytes(in_size, 2); + compute_encoder.set_bytes(row_size, 3); + compute_encoder.dispatch_threads(grid_dims, group_dims); // 2nd pass std::ostringstream kname_2nd_pass; kname_2nd_pass << func_name << "_" << op_name << type_to_name(intermediate); auto kernel_2nd_pass = get_reduce_kernel( d, kname_2nd_pass.str(), func_name, op_name, intermediate, out); - compute_encoder->setComputePipelineState(kernel_2nd_pass); + compute_encoder.set_compute_pipeline_state(kernel_2nd_pass); size_t intermediate_size = n_rows; grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1); group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1); compute_encoder.set_input_array(intermediate, 0); compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2); - compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 3); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.set_bytes(intermediate_size, 2); + compute_encoder.set_bytes(intermediate_size, 3); + compute_encoder.dispatch_threads(grid_dims, group_dims); } } @@ -355,7 +349,7 @@ void row_reduce_small( kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in); auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Figure out the grid dims MTL::Size grid_dims; @@ -375,7 +369,7 @@ void row_reduce_small( compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); args.encode(compute_encoder); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } void row_reduce_simple( @@ -391,7 +385,7 @@ void row_reduce_simple( const std::string func_name = "row_reduce_simple"; kname << func_name << "_" << op_name << type_to_name(in); auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Figure out the grid dims size_t row_size = args.row_size; @@ -410,9 +404,9 @@ void row_reduce_simple( // Launch compute_encoder.set_input_array(in, 
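// Worth noting for these conversions (illustrative comment, not patch content):
// the wrapped set_bytes deduces the encoded length from the argument's declared
// C++ type, so the explicit sizeof(...) the old setBytes calls carried is now
// implicit and the variable's type must match the kernel signature exactly.
size_t in_size = in.size();
compute_encoder.set_bytes(in_size, 2);  // previously: setBytes(&in_size, sizeof(size_t), 2)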
0); compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&row_size, sizeof(size_t), 2); - compute_encoder->setBytes(&out_size, sizeof(size_t), 3); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.set_bytes(row_size, 2); + compute_encoder.set_bytes(out_size, 3); + compute_encoder.dispatch_threads(grid_dims, group_dims); } void row_reduce_looped( @@ -430,7 +424,7 @@ void row_reduce_looped( kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in); auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Figure out the grid auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides()); @@ -443,7 +437,7 @@ void row_reduce_looped( compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); args.encode(compute_encoder); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } void row_reduce_general_dispatch( @@ -495,7 +489,7 @@ void strided_reduce_small( kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in); auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); const int n_reads = 4; size_t reduction_stride_blocks = @@ -517,7 +511,7 @@ void strided_reduce_small( compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); args.encode(compute_encoder); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void strided_reduce_longcolumn( @@ -568,14 +562,14 @@ void strided_reduce_longcolumn( kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in); auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Launch compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(intermediate, 1); args.encode(compute_encoder); - compute_encoder->setBytes(&out_size, sizeof(size_t), 11); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.set_bytes(out_size, 11); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); // Make the 2nd pass arguments and grid_dims ColReduceArgs second_args(intermediate); @@ -599,12 +593,12 @@ void strided_reduce_longcolumn( 1, 32, 32); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(intermediate, 0); compute_encoder.set_output_array(out, 1); second_args.encode(compute_encoder); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } void strided_reduce_looped( @@ -639,13 +633,13 @@ void strided_reduce_looped( << op_name << type_to_name(in); auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Launch compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); args.encode(compute_encoder); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } void strided_reduce_2pass( @@ -692,14 
+686,14 @@ void strided_reduce_2pass( << op_name << type_to_name(in); auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Launch compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(intermediate, 1); args.encode(compute_encoder); - compute_encoder->setBytes(&out_size, sizeof(size_t), 11); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.set_bytes(out_size, 11); + compute_encoder.dispatch_threads(grid_dims, group_dims); // Make the 2nd pass arguments and grid_dims ColReduceArgs second_args(intermediate); @@ -721,12 +715,12 @@ void strided_reduce_2pass( 1, 32, 32); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(intermediate, 0); compute_encoder.set_output_array(out, 1); second_args.encode(compute_encoder); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } void strided_reduce_general_dispatch( diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp index fc6aa347c..195d29c2e 100644 --- a/mlx/backend/metal/rope.cpp +++ b/mlx/backend/metal/rope.cpp @@ -75,24 +75,24 @@ void RoPE::eval_gpu( auto& compute_encoder = d.get_command_encoder(s.index); float base = std::log2(base_); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(donated ? out : in, 0); compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&offset_, sizeof(int), 2); - compute_encoder->setBytes(&scale_, sizeof(float), 3); + compute_encoder.set_bytes(offset_, 2); + compute_encoder.set_bytes(scale_, 3); size_t n_batch = in.size() / mat_size; MTL::Size group_dims; MTL::Size grid_dims; if (single) { - compute_encoder->setBytes(out_strides, sizeof(size_t), 4); + compute_encoder.set_bytes(out_strides, 1, 4); uint32_t dim0 = dims_ / 2; group_dims = get_block_dims(dim0, n_batch, 1); grid_dims = MTL::Size(dim0, n_batch, 1); } else { - compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 4); - compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 5); - compute_encoder->setBytes(&n_batch, sizeof(size_t), 6); + compute_encoder.set_bytes(strides, 3, 4); + compute_encoder.set_bytes(out_strides, 3, 5); + compute_encoder.set_bytes(n_batch, 6); uint32_t dim0 = dims_ / 2; uint32_t dim1 = in.shape(-2); uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread; @@ -104,11 +104,11 @@ void RoPE::eval_gpu( auto& freqs = inputs[1]; compute_encoder.set_input_array(freqs, 10); auto freq_stride = freqs.strides()[0]; - compute_encoder->setBytes(&freq_stride, sizeof(size_t), 11); + compute_encoder.set_bytes(freq_stride, 11); } else { - compute_encoder->setBytes(&base, sizeof(float), 10); + compute_encoder.set_bytes(base, 10); } - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } } // namespace mlx::core::fast diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 19af2a850..3071650a5 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -59,7 +59,7 @@ void sdpa_full_self_attention_metal( auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = 
d.get_kernel(kname_self_attention.str()); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); uint hidden_dim = q.shape(-1); uint qseq = q.shape(-2); @@ -129,17 +129,14 @@ void sdpa_full_self_attention_metal( compute_encoder.set_input_array(v, 2); compute_encoder.set_output_array(out, 3); - compute_encoder->setBytes(¶ms, sizeof(MLXFastAttentionParams), 4); - compute_encoder->setBytes( - batch_shape.data(), sizeof(int) * batch_shape.size(), 6); - - compute_encoder->setBytes( - batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7); + compute_encoder.set_bytes(params, 4); + compute_encoder.set_vector_bytes(batch_shape, 6); + compute_encoder.set_vector_bytes(batch_strides, 7); MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out); MTL::Size group_dims = MTL::Size(32, wm, wn); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void sdpa_vector( @@ -170,21 +167,21 @@ void sdpa_vector( // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Set its arguments compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0); compute_encoder.set_input_array(k, 1); compute_encoder.set_input_array(v, 2); compute_encoder.set_output_array(out, 3); - compute_encoder->setBytes(&gqa_factor, sizeof(int), 4); - compute_encoder->setBytes(&N, sizeof(int), 5); - compute_encoder->setBytes(&k_stride, sizeof(size_t), 6); - compute_encoder->setBytes(&v_stride, sizeof(size_t), 7); - compute_encoder->setBytes(&scale, sizeof(float), 8); + compute_encoder.set_bytes(gqa_factor, 4); + compute_encoder.set_bytes(N, 5); + compute_encoder.set_bytes(k_stride, 6); + compute_encoder.set_bytes(v_stride, 7); + compute_encoder.set_bytes(scale, 8); // Launch - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } } // namespace diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index ae9c6a66f..46c2a9bea 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -68,12 +68,12 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { if (contiguous) { auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array( in.data_shared_ptr() == nullptr ? out : in, 0); compute_encoder.set_output_array(out, 1); size_t size = in.shape(axis_); - compute_encoder->setBytes(&size, sizeof(size_t), 2); + compute_encoder.set_bytes(size, 2); // Compute the thread grid int n_reads = (in.itemsize() <= 4) ? 4 : 2; @@ -95,10 +95,10 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { MTL::Size grid_dims( thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height); MTL::Size group_dims(thread_group_size, 1, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } else { auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array( in.data_shared_ptr() == nullptr ? 
out : in, 0); compute_encoder.set_output_array(out, 1); @@ -107,9 +107,9 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { int bm = 32; int bn = 32; size_t stride_blocks = (stride + bn - 1) / bn; - compute_encoder->setBytes(&size, sizeof(size_t), 2); - compute_encoder->setBytes(&stride, sizeof(size_t), 3); - compute_encoder->setBytes(&stride_blocks, sizeof(size_t), 4); + compute_encoder.set_bytes(size, 2); + compute_encoder.set_bytes(stride, 3); + compute_encoder.set_bytes(stride_blocks, 4); // Compute the thread grid int n_reads = (in.itemsize() <= 4) ? 4 : 2; @@ -125,7 +125,7 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { MTL::Size grid_dims( thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height); MTL::Size group_dims(thread_group_size, 1, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } d.add_temporaries(std::move(copies), s.index); diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 5bb7e66a4..732c5d98c 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -81,12 +81,12 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { group_dims = MTL::Size(threadgroup_size, 1, 1); } - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array( in.data_shared_ptr() == nullptr ? out : in, 0); compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&axis_size, sizeof(int), 2); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.set_bytes(axis_size, 2); + compute_encoder.dispatch_threads(grid_dims, group_dims); } d.add_temporaries(std::move(copies), s.index); diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index 925a5ccd9..d0d28e20c 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -68,29 +68,29 @@ void single_block_sort( // Prepare command encoder auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); // Set inputs compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); - compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2); - compute_encoder->setBytes(&in_stride_sorted_axis, sizeof(int), 3); - compute_encoder->setBytes(&out_stride_sorted_axis, sizeof(int), 4); + compute_encoder.set_bytes(size_sorted_axis, 2); + compute_encoder.set_bytes(in_stride_sorted_axis, 3); + compute_encoder.set_bytes(out_stride_sorted_axis, 4); if (contiguous) { - compute_encoder->setBytes(&in_stride_segment_axis, sizeof(int), 5); - compute_encoder->setBytes(&out_stride_segment_axis, sizeof(int), 6); + compute_encoder.set_bytes(in_stride_segment_axis, 5); + compute_encoder.set_bytes(out_stride_segment_axis, 6); } else { - compute_encoder->setBytes(&nc_dim, sizeof(int), 5); - compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6); - compute_encoder->setBytes(in_nc_str.data(), nc_dim * sizeof(size_t), 7); - compute_encoder->setBytes(out_nc_str.data(), nc_dim * sizeof(size_t), 8); + compute_encoder.set_bytes(nc_dim, 5); + compute_encoder.set_vector_bytes(nc_shape, 6); + compute_encoder.set_vector_bytes(in_nc_str, 7); + compute_encoder.set_vector_bytes(out_nc_str, 8); } MTL::Size group_dims = MTL::Size(bn, 1, 1); MTL::Size grid_dims = MTL::Size(1, n_rows, 1); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + 
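// Minimal sketch, same assumptions as above, of what the dispatch wrappers used
// throughout this patch might look like; the real definitions live in device.h
// and may do additional bookkeeping before forwarding to the Metal encoder.
void CommandEncoder::dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims) {
  enc_->dispatchThreads(grid_dims, group_dims);
}
void CommandEncoder::dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims) {
  enc_->dispatchThreadgroups(grid_dims, group_dims);
}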
compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void multi_block_sort( @@ -152,22 +152,21 @@ void multi_block_sort( << type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn; auto kernel = get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(dev_vals_0, 1); compute_encoder.set_output_array(dev_idxs_0, 2); - compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3); - compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4); - compute_encoder->setBytes(&nc_dim, sizeof(int), 5); - compute_encoder->setBytes( - nc_shape.data(), nc_shape.size() * sizeof(int), 6); - compute_encoder->setBytes(nc_str.data(), nc_str.size() * sizeof(size_t), 7); + compute_encoder.set_bytes(size_sorted_axis, 3); + compute_encoder.set_bytes(stride_sorted_axis, 4); + compute_encoder.set_bytes(nc_dim, 5); + compute_encoder.set_vector_bytes(nc_shape, 6); + compute_encoder.set_vector_bytes(nc_str, 7); MTL::Size group_dims = MTL::Size(bn, 1, 1); MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } // Do merges @@ -194,19 +193,19 @@ void multi_block_sort( auto kernel = get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_output_array(block_partitions, 0); compute_encoder.set_input_array(dev_vals_in, 1); compute_encoder.set_input_array(dev_idxs_in, 2); - compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3); - compute_encoder->setBytes(&merge_tiles, sizeof(int), 4); - compute_encoder->setBytes(&n_blocks, sizeof(int), 5); + compute_encoder.set_bytes(size_sorted_axis, 3); + compute_encoder.set_bytes(merge_tiles, 4); + compute_encoder.set_bytes(n_blocks, 5); MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1); MTL::Size grid_dims = MTL::Size(1, n_rows, 1); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } // Do merge @@ -217,21 +216,21 @@ void multi_block_sort( auto kernel = get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(block_partitions, 0); compute_encoder.set_input_array(dev_vals_in, 1); compute_encoder.set_input_array(dev_idxs_in, 2); compute_encoder.set_output_array(dev_vals_out, 3); compute_encoder.set_output_array(dev_idxs_out, 4); - compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5); - compute_encoder->setBytes(&merge_tiles, sizeof(int), 6); - compute_encoder->setBytes(&n_blocks, sizeof(int), 7); + compute_encoder.set_bytes(size_sorted_axis, 5); + compute_encoder.set_bytes(merge_tiles, 6); + compute_encoder.set_bytes(n_blocks, 7); MTL::Size group_dims = MTL::Size(bn, 1, 1); MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1); - compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 0f82f9894..0d0fdc657 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -63,7 +63,7 @@ void ternary_op_gpu_inplace( auto 
kernel = get_ternary_kernel(d, kernel_name, out.dtype(), op); auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); bool donate_a = a.data_shared_ptr() == nullptr; bool donate_b = b.data_shared_ptr() == nullptr; bool donate_c = c.data_shared_ptr() == nullptr; @@ -80,18 +80,18 @@ void ternary_op_gpu_inplace( size_t rest = out.size() / (dim0 * dim1); if (ndim > 3) { - compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4); - compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5); - compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6); - compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7); + compute_encoder.set_vector_bytes(shape, 4); + compute_encoder.set_vector_bytes(strides_a, 5); + compute_encoder.set_vector_bytes(strides_b, 6); + compute_encoder.set_vector_bytes(strides_c, 7); - compute_encoder->setBytes(&ndim, sizeof(int), 8); + compute_encoder.set_bytes(ndim, 8); dim0 = (dim0 + work_per_thread - 1) / work_per_thread; } else { // The shape is implicit in the grid for <= 3D - compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4); - compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5); - compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6); + compute_encoder.set_vector_bytes(strides_a, 4); + compute_encoder.set_vector_bytes(strides_b, 5); + compute_encoder.set_vector_bytes(strides_c, 6); } if (thread_group_size != 1024) { @@ -99,7 +99,7 @@ void ternary_op_gpu_inplace( } MTL::Size group_dims = get_block_dims(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } else { // Launch a 1D or 2D grid of threads size_t nthreads = out.data_size(); @@ -109,7 +109,7 @@ void ternary_op_gpu_inplace( MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) : MTL::Size(nthreads, 1, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index 8f061a3b7..8d23d4192 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -49,7 +49,7 @@ void unary_op_gpu_inplace( auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array( in.data_shared_ptr() == nullptr ? out : in, 0); compute_encoder.set_output_array(out, 1); @@ -58,16 +58,16 @@ void unary_op_gpu_inplace( size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim1 = ndim > 1 ? 
shape[ndim - 2] : 1; size_t rest = out.size() / (dim0 * dim1); - compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2); - compute_encoder->setBytes(strides.data(), ndim * sizeof(size_t), 3); - compute_encoder->setBytes(&ndim, sizeof(int), 4); + compute_encoder.set_vector_bytes(shape, 2); + compute_encoder.set_vector_bytes(strides, 3); + compute_encoder.set_bytes(ndim, 4); if (thread_group_size != 1024) { throw std::runtime_error("[Metal::unary] Must use 1024 sized block"); } dim0 = (dim0 + work_per_thread - 1) / work_per_thread; auto group_dims = get_block_dims(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } else { if (thread_group_size > nthreads) { thread_group_size = nthreads; @@ -75,7 +75,7 @@ void unary_op_gpu_inplace( MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) : MTL::Size(nthreads, 1, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); + compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 509a7e651..f2a9c7b20 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -8,23 +8,6 @@ namespace mlx::core { -using metal::CommandEncoder; - -template -inline void set_vector_bytes( - CommandEncoder& enc, - const std::vector& vec, - size_t nelems, - int idx) { - enc->setBytes(vec.data(), nelems * sizeof(T), idx); -} - -template -inline void -set_vector_bytes(CommandEncoder& enc, const std::vector& vec, int idx) { - return set_vector_bytes(enc, vec, vec.size(), idx); -} - std::string type_to_name(const array& a); // Compute the thread block dimensions which fit the given
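// With the free set_vector_bytes helpers above removed, the migration for any
// remaining call site is mechanical (condensed from the changes in this patch):
//
//   set_vector_bytes(compute_encoder, strides, 3);     // old free helper
//   compute_encoder->setBytes(&ndim, sizeof(int), 4);  // old raw encoder call
//
//   compute_encoder.set_vector_bytes(strides, 3);      // new member
//   compute_encoder.set_bytes(ndim, 4);                // new member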