diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp
index b7327348a..e3da32c4b 100644
--- a/mlx/backend/metal/compiled.cpp
+++ b/mlx/backend/metal/compiled.cpp
@@ -289,7 +289,7 @@ void Compiled::eval_gpu(
     }
   }
   auto kernel = d.get_kernel(kernel_name, lib);
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
 
   // Put the inputs in
@@ -300,7 +300,7 @@
       continue;
    }
     auto& x = inputs[i];
-    set_array_buffer(compute_encoder, x, cnt++);
+    compute_encoder.set_input_array(x, cnt++);
     if (!contiguous && !is_scalar(x)) {
       compute_encoder->setBytes(
           strides[stride_idx].data(),
@@ -315,7 +315,7 @@
 
   // Put the outputs in
   for (auto& x : outputs) {
-    set_array_buffer(compute_encoder, x, cnt++);
+    compute_encoder.set_output_array(x, cnt++);
   }
 
   // Put the output shape and strides in
diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index b954eb7e3..f2b3553f7 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -41,12 +41,12 @@ void explicit_gemm_conv_ND_gpu(
   // Prepare unfolding kernel
   std::ostringstream kname;
   kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, in_unfolded, 1);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(in_unfolded, 1);
 
   compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
 
@@ -140,7 +140,7 @@ void slow_conv_2D_gpu(
         << "_tm" << tm << "_tn" << tn;
 
   // Encode and dispatch kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
@@ -153,9 +153,9 @@ void slow_conv_2D_gpu(
   MTL::Size group_dims = MTL::Size(bm, bn, 1);
   MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
 
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, wt, 1);
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_input_array(wt, 1);
+  compute_encoder.set_output_array(out, 2);
 
   compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
   compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
@@ -241,7 +241,7 @@ void implicit_gemm_conv_2D_gpu(
         << "_filter_" << (small_filter ? 's' : 'l');
 
   // Encode and dispatch kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
@@ -254,9 +254,9 @@ void implicit_gemm_conv_2D_gpu(
   MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
 
   // Encode arrays
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, wt, 1);
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_input_array(wt, 1);
+  compute_encoder.set_output_array(out, 2);
 
   // Encode params
   compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
@@ -394,7 +394,7 @@ void implicit_gemm_conv_2D_general_gpu(
         << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
 
   // Encode and dispatch kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
@@ -408,9 +408,9 @@ void implicit_gemm_conv_2D_general_gpu(
   MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
 
   // Encode arrays
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, wt, 1);
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_input_array(wt, 1);
+  compute_encoder.set_output_array(out, 2);
 
   // Encode params
   compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
@@ -511,12 +511,12 @@ void winograd_conv_2D_gpu(
   std::ostringstream kname;
   kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
         << bc;
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
-  set_array_buffer(compute_encoder, wt, 0);
-  set_array_buffer(compute_encoder, filt_wg, 1);
+  compute_encoder.set_input_array(wt, 0);
+  compute_encoder.set_output_array(filt_wg, 1);
 
   compute_encoder->setBytes(&C_c, sizeof(int), 2);
   compute_encoder->setBytes(&O_c, sizeof(int), 3);
@@ -539,12 +539,12 @@ void winograd_conv_2D_gpu(
   std::ostringstream kname;
   kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
         << bc;
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
-  set_array_buffer(compute_encoder, in_padded, 0);
-  set_array_buffer(compute_encoder, inp_wg, 1);
+  compute_encoder.set_input_array(in_padded, 0);
+  compute_encoder.set_output_array(inp_wg, 1);
 
   compute_encoder->setBytes(
       &conv_params_updated, sizeof(MLXConvParams<2>), 2);
@@ -587,12 +587,12 @@ void winograd_conv_2D_gpu(
   std::ostringstream kname;
   kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
         << bc;
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
-  set_array_buffer(compute_encoder, out_wg, 0);
-  set_array_buffer(compute_encoder, out, 1);
+  compute_encoder.set_input_array(out_wg, 0);
+  compute_encoder.set_output_array(out, 1);
 
   compute_encoder->setBytes(
       &conv_params_updated, sizeof(MLXConvParams<2>), 2);
diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp
index cb5fe289a..fd4e920f6 100644
--- a/mlx/backend/metal/copy.cpp
+++ b/mlx/backend/metal/copy.cpp
@@ -83,15 +83,15 @@ void copy_gpu_inplace(
     kname << "_" << shape.size();
   }
   auto kernel = d.get_kernel(kname.str());
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
   bool donate_in = in.data_shared_ptr() == nullptr;
 
   inp_offset *= size_of(in.dtype());
   out_offset *= size_of(out.dtype());
 
-  set_array_buffer(compute_encoder, donate_in ? out : in, inp_offset, 0);
-  set_array_buffer(compute_encoder, out, out_offset, 1);
+  compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
+  compute_encoder.set_output_array(out, 1, out_offset);
 
   if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
     int ndim = shape.size();
diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index c814b70b9..844615284 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -1,4 +1,4 @@
-// Copyright © 2023-24 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <dlfcn.h>
 #include <filesystem>
@@ -206,14 +206,15 @@ void Device::end_encoding(int index) {
   }
 }
 
-MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
+CommandEncoder& Device::get_command_encoder(int index) {
   auto eit = encoder_map_.find(index);
   if (eit == encoder_map_.end()) {
     auto cb = get_command_buffer(index);
-    auto compute_encoder = cb->computeCommandEncoder();
+    auto compute_encoder =
+        cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
     // Increment ref count so the buffer is not garbage collected
     compute_encoder->retain();
-    eit = encoder_map_.insert({index, compute_encoder}).first;
+    eit = encoder_map_.emplace(index, CommandEncoder{compute_encoder}).first;
   }
   return eit->second;
 }
diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h
index 8312084ce..4fc43c164 100644
--- a/mlx/backend/metal/device.h
+++ b/mlx/backend/metal/device.h
@@ -1,4 +1,4 @@
-// Copyright © 2023-24 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #pragma once
 
@@ -7,10 +7,12 @@
 #include <filesystem>
 #include <mutex>
 #include <string>
+#include <unordered_set>
 #include <unordered_map>
 #include <vector>
 
+#include "mlx/array.h"
 #include "mlx/device.h"
 
 namespace fs = std::filesystem;
 
@@ -34,6 +36,69 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
 
 using MTLFCList =
     std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
 
+struct CommandEncoder {
+  CommandEncoder(MTL::ComputeCommandEncoder* enc)
+      : enc(enc), concurrent(false){};
+  CommandEncoder& operator=(const CommandEncoder&) = delete;
+
+  struct ConcurrentContext {
+    ConcurrentContext(CommandEncoder& enc) : enc(enc) {
+      enc.concurrent = true;
+    }
+    ~ConcurrentContext() {
+      enc.concurrent = false;
+      enc.outputs.insert(
+          enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
+      enc.concurrent_outputs.clear();
+    }
+
+   private:
+    CommandEncoder& enc;
+  };
+
+  MTL::ComputeCommandEncoder* operator->() {
+    return enc;
+  }
+
+  void set_input_array(const array& a, int idx, int offset = 0) {
+    auto r_buf =
+        static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
+    if (auto it = outputs.find(r_buf); it != outputs.end()) {
+      // Insert a barrier
+      enc->memoryBarrier(&r_buf, 1);
+
+      // Remove the output
+      outputs.erase(it);
+    }
+    auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
+    auto base_offset = a.data<char>() -
+        static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
+    base_offset += offset;
+    enc->setBuffer(a_buf, base_offset, idx);
+  }
+
+  void set_output_array(array& a, int idx, int offset = 0) {
+    // Add barriers before adding the output to the output set
+    set_input_array(a, idx, offset);
+    auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
+    if (concurrent) {
+      concurrent_outputs.insert(buf);
+    } else {
+      outputs.insert(buf);
+    }
+  }
+
+  ConcurrentContext start_concurrent() {
+    return ConcurrentContext(*this);
+  }
+
+ private:
+  MTL::ComputeCommandEncoder* enc;
+  bool concurrent;
+  std::unordered_set<MTL::Resource*> outputs;
+  std::unordered_set<MTL::Resource*> concurrent_outputs;
+};
+
 class Device {
  public:
   Device();
@@ -51,7 +116,7 @@ class Device {
   int get_command_buffer_ops(int index);
   void increment_command_buffer_ops(int index);
   void commit_command_buffer(int index);
-  MTL::ComputeCommandEncoder* get_command_encoder(int index);
+  CommandEncoder& get_command_encoder(int index);
   void end_encoding(int index);
 
   void register_library(
@@ -132,7 +197,7 @@ class Device {
   MTL::Device* device_;
   std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
   std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
-  std::unordered_map<int32_t, MTL::ComputeCommandEncoder*> encoder_map_;
+  std::unordered_map<int32_t, CommandEncoder> encoder_map_;
   std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
   std::unordered_map<std::string, MTL::Library*> library_map_;
   std::mutex mtx_;
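Note: `CommandEncoder` is what makes the switch to `MTL::DispatchTypeConcurrent` in `Device::get_command_encoder` safe. It records every buffer bound through `set_output_array` as an `MTL::Resource*`, and `set_input_array` emits a `memoryBarrier` the first time one of those buffers is read back, so data dependencies between concurrently dispatched kernels stay ordered. A minimal usage sketch; the names `d`, `s`, `kernel`, `in`, `out`, and the dispatch sizes stand for what a typical `eval_gpu` method has in scope, and this is an illustration, not code from the patch:

```cpp
// The encoder is now held by reference; copy assignment is deleted.
auto& compute_encoder = d.get_command_encoder(s.index);

// operator-> still forwards to the raw MTL::ComputeCommandEncoder.
compute_encoder->setComputePipelineState(kernel);

// Binds `in` at buffer index 0; if `in` was written by an earlier
// dispatch on this encoder, a memory barrier is inserted first.
compute_encoder.set_input_array(in, 0);

// Binds `out` at buffer index 1 and records it as written, so any
// later kernel that reads `out` gets a barrier.
compute_encoder.set_output_array(out, 1);

compute_encoder->dispatchThreads(grid_dims, group_dims);
```

One migration gotcha visible in the `copy.cpp` hunk above: the offset overload changes argument order, from `set_array_buffer(enc, a, offset, idx)` to `set_input_array(a, idx, offset)`.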
diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp
index 67095a021..b40c9c8c9 100644
--- a/mlx/backend/metal/indexing.cpp
+++ b/mlx/backend/metal/indexing.cpp
@@ -49,7 +49,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
     kname << "_" << idx_ndim;
   }
 
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
@@ -81,8 +81,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   // Set all the buffers
-  set_array_buffer(compute_encoder, src, 0);
-  set_array_buffer(compute_encoder, out, 1);
+  compute_encoder.set_input_array(src, 0);
+  compute_encoder.set_output_array(out, 1);
 
   // Set source info
   compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
@@ -103,7 +103,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // Set index buffers
   for (int i = 1; i < nidx + 1; ++i) {
-    set_array_buffer(compute_encoder, inputs[i], 20 + i);
+    compute_encoder.set_input_array(inputs[i], 20 + i);
   }
 
   // Launch grid
@@ -183,7 +183,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
   kname << "_" << nidx;
 
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
 
   auto& upd = inputs.back();
@@ -192,8 +192,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setComputePipelineState(kernel);
 
   // Set all the buffers
-  set_array_buffer(compute_encoder, upd, 1);
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_input_array(upd, 1);
+  compute_encoder.set_output_array(out, 2);
 
   // Set update info
   uint upd_ndim = upd.ndim();
@@ -210,7 +210,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // Set index buffers
   for (int i = 1; i < nidx + 1; ++i) {
-    set_array_buffer(compute_encoder, inputs[i], 20 + i);
+    compute_encoder.set_input_array(inputs[i], 20 + i);
   }
 
   // Launch grid
@@ -280,7 +280,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   // Set index buffers
   for (int i = 1; i < nidx + 1; ++i) {
-    set_array_buffer(compute_encoder, inputs[i], 20 + i);
+    compute_encoder.set_input_array(inputs[i], 20 + i);
   }
 
   // Launch grid
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index d69ad5998..0ea89e51b 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -336,7 +336,7 @@ void steel_matmul(
         << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
 
   // Encode and dispatch gemm kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
@@ -360,9 +360,9 @@ void steel_matmul(
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
 
-    set_array_buffer(compute_encoder, a, 0);
-    set_array_buffer(compute_encoder, b, 1);
-    set_array_buffer(compute_encoder, C_split, 2);
+    compute_encoder.set_input_array(a, 0);
+    compute_encoder.set_input_array(b, 1);
+    compute_encoder.set_output_array(C_split, 2);
 
     compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
     compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
@@ -380,8 +380,8 @@ void steel_matmul(
     compute_encoder->setComputePipelineState(kernel);
 
     // Set the arguments for the kernel
-    set_array_buffer(compute_encoder, C_split, 0);
-    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder.set_input_array(C_split, 0);
+    compute_encoder.set_output_array(out, 1);
     compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
     compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
     compute_encoder->setBytes(&N, sizeof(int), 4);
@@ -426,7 +426,7 @@ void steel_matmul(
        << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
"t" : "n") << "aligned"; // Encode and dispatch kernel - auto compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); compute_encoder->setComputePipelineState(kernel); @@ -467,9 +467,9 @@ void steel_matmul( batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); // Launch kernel - set_array_buffer(compute_encoder, a, 0); - set_array_buffer(compute_encoder, b, 1); - set_array_buffer(compute_encoder, out, 3); + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_output_array(out, 3); compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4); @@ -622,7 +622,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { kname << "_nc" << !contiguous_kernel << "_axpby0"; // Encode and dispatch kernel - auto compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); compute_encoder->setComputePipelineState(kernel); @@ -630,9 +630,9 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { MTL::Size group_dims = MTL::Size(bn, bm, 1); MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); - set_array_buffer(compute_encoder, mat, 0); - set_array_buffer(compute_encoder, vec, 1); - set_array_buffer(compute_encoder, out, 3); + compute_encoder.set_input_array(mat, 0); + compute_encoder.set_input_array(vec, 1); + compute_encoder.set_output_array(out, 3); compute_encoder->setBytes(&in_vector_len, sizeof(int), 4); compute_encoder->setBytes(&out_vector_len, sizeof(int), 5); @@ -834,7 +834,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { kname << "_nc" << !contiguous_kernel << "_axpby1"; // Encode and dispatch kernel - auto compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); compute_encoder->setComputePipelineState(kernel); @@ -842,10 +842,10 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { MTL::Size group_dims = MTL::Size(bn, bm, 1); MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); - set_array_buffer(compute_encoder, mat, 0); - set_array_buffer(compute_encoder, vec, 1); - set_array_buffer(compute_encoder, c, 2); - set_array_buffer(compute_encoder, out, 3); + compute_encoder.set_input_array(mat, 0); + compute_encoder.set_input_array(vec, 1); + compute_encoder.set_input_array(c, 2); + compute_encoder.set_output_array(out, 3); compute_encoder->setBytes(&in_vector_len, sizeof(int), 4); compute_encoder->setBytes(&out_vector_len, sizeof(int), 5); @@ -907,7 +907,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { << "_K_" << ((K % bk == 0) ? 
"t" : "n") << "aligned"; // Encode and dispatch gemm kernel - auto compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); compute_encoder->setComputePipelineState(kernel); @@ -931,9 +931,9 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); - set_array_buffer(compute_encoder, a, 0); - set_array_buffer(compute_encoder, b, 1); - set_array_buffer(compute_encoder, C_split, 2); + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_output_array(C_split, 2); compute_encoder->setBytes(¶ms, sizeof(GEMMSpiltKParams), 3); compute_encoder->dispatchThreadgroups(grid_dims, group_dims); @@ -946,12 +946,12 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setComputePipelineState(kernel); // Set the arguments for the kernel - set_array_buffer(compute_encoder, C_split, 0); - set_array_buffer(compute_encoder, out, 1); + compute_encoder.set_input_array(C_split, 0); + compute_encoder.set_output_array(out, 1); compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2); compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3); compute_encoder->setBytes(&N, sizeof(int), 4); - set_array_buffer(compute_encoder, c, 5); + compute_encoder.set_input_array(c, 5); compute_encoder->setBytes(&ldc, sizeof(int), 6); compute_encoder->setBytes(&fdc, sizeof(int), 7); compute_encoder->setBytes(&alpha_, sizeof(float), 8); @@ -997,7 +997,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { << ((alpha_ == 1. && beta_ == 1.) ? "_add" : "_axpby"); // Encode and dispatch kernel - auto compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); compute_encoder->setComputePipelineState(kernel); @@ -1045,10 +1045,10 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); // Launch kernel - set_array_buffer(compute_encoder, a, 0); - set_array_buffer(compute_encoder, b, 1); - set_array_buffer(compute_encoder, c, 2); - set_array_buffer(compute_encoder, out, 3); + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_input_array(c, 2); + compute_encoder.set_output_array(out, 3); compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4); compute_encoder->setBytes(¶ms, sizeof(GEMMAddMMParams), 5); diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index f57498036..c18e5c658 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -88,7 +88,6 @@ std::function make_task( if (!arr.is_tracer()) { arr.detach(); } - if (p) { metal::device(s.device).end_encoding(s.index); scheduler::notify_new_task(s); diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index cd7ad7eac..61728b5f9 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -57,7 +57,7 @@ void RMSNorm::eval_gpu( op_name += "_looped"; } op_name += type_to_name(out); - auto compute_encoder = d.get_command_encoder(s.index); + auto& compute_encoder = d.get_command_encoder(s.index); { auto kernel = d.get_kernel(op_name); @@ -79,10 +79,10 @@ void RMSNorm::eval_gpu( uint32_t w_stride = w.strides()[0]; compute_encoder->setComputePipelineState(kernel); 
-    set_array_buffer(
-        compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
-    set_array_buffer(compute_encoder, w, 1);
-    set_array_buffer(compute_encoder, out, 2);
+    compute_encoder.set_input_array(
+        x.data_shared_ptr() == nullptr ? out : x, 0);
+    compute_encoder.set_input_array(w, 1);
+    compute_encoder.set_output_array(out, 2);
     compute_encoder->setBytes(&eps_, sizeof(float), 3);
     compute_encoder->setBytes(&axis_size, sizeof(int), 4);
     compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 5);
@@ -160,7 +160,7 @@ void RMSNormVJP::eval_gpu(
     op_name += "_looped";
   }
   op_name += type_to_name(gx);
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   {
     auto kernel = d.get_kernel(op_name);
 
@@ -182,12 +182,11 @@ void RMSNormVJP::eval_gpu(
     uint32_t w_stride = w.strides()[0];
     compute_encoder->setComputePipelineState(kernel);
-    set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0);
-    set_array_buffer(compute_encoder, w, 1);
-    set_array_buffer(
-        compute_encoder, g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
-    set_array_buffer(compute_encoder, gx, 3);
-    set_array_buffer(compute_encoder, gw_temp, 4);
+    compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
+    compute_encoder.set_input_array(w, 1);
+    compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
+    compute_encoder.set_output_array(gx, 3);
+    compute_encoder.set_output_array(gw_temp, 4);
     compute_encoder->setBytes(&eps_, sizeof(float), 5);
     compute_encoder->setBytes(&axis_size, sizeof(int), 6);
     compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
@@ -251,7 +250,7 @@ void LayerNorm::eval_gpu(
     op_name += "_looped";
   }
   op_name += type_to_name(out);
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   {
     auto kernel = d.get_kernel(op_name);
 
@@ -274,11 +273,11 @@ void LayerNorm::eval_gpu(
     uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
     uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
     compute_encoder->setComputePipelineState(kernel);
-    set_array_buffer(
-        compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
-    set_array_buffer(compute_encoder, w, 1);
-    set_array_buffer(compute_encoder, b, 2);
-    set_array_buffer(compute_encoder, out, 3);
+    compute_encoder.set_input_array(
+        x.data_shared_ptr() == nullptr ? out : x, 0);
+    compute_encoder.set_input_array(w, 1);
+    compute_encoder.set_input_array(b, 2);
+    compute_encoder.set_output_array(out, 3);
     compute_encoder->setBytes(&eps_, sizeof(float), 4);
     compute_encoder->setBytes(&axis_size, sizeof(int), 5);
     compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
@@ -350,7 +349,7 @@ void LayerNormVJP::eval_gpu(
   }
 
   // Finish with the gradient for b in case we had a b
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   if (gb.ndim() == 1 && gb.size() == axis_size) {
     ReductionPlan plan(
         ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
@@ -394,12 +393,11 @@ void LayerNormVJP::eval_gpu(
     uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
     compute_encoder->setComputePipelineState(kernel);
-    set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0);
-    set_array_buffer(compute_encoder, w, 1);
-    set_array_buffer(
-        compute_encoder, g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
-    set_array_buffer(compute_encoder, gx, 3);
-    set_array_buffer(compute_encoder, gw_temp, 4);
+    compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
+    compute_encoder.set_input_array(w, 1);
+    compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
+    compute_encoder.set_output_array(gx, 3);
+    compute_encoder.set_output_array(gw_temp, 4);
     compute_encoder->setBytes(&eps_, sizeof(float), 5);
     compute_encoder->setBytes(&axis_size, sizeof(int), 6);
     compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index a1497a34f..9137ff12f 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -68,18 +68,18 @@ void binary_op(
   auto& s = out.primitive().stream();
   auto& d = metal::device(s.device);
   auto kernel = d.get_kernel(kname.str());
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
 
   // - If a is donated it goes to the first output
   // - If b is donated it goes to the first output if a was not donated
   //   otherwise it goes to the second output
   bool donate_a = a.data_shared_ptr() == nullptr;
   bool donate_b = b.data_shared_ptr() == nullptr;
-  set_array_buffer(compute_encoder, donate_a ? outputs[0] : a, 0);
-  set_array_buffer(
-      compute_encoder, donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
-  set_array_buffer(compute_encoder, outputs[0], 2);
-  set_array_buffer(compute_encoder, outputs[1], 3);
+  compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
+  compute_encoder.set_input_array(
+      donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
+  compute_encoder.set_output_array(outputs[0], 2);
+  compute_encoder.set_output_array(outputs[1], 3);
 
   if (bopt == BinaryOpType::General) {
     auto ndim = shape.size();
@@ -167,13 +167,13 @@ void binary_op(
   auto& s = out.primitive().stream();
   auto& d = metal::device(s.device);
   auto kernel = d.get_kernel(kname.str());
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
   bool donate_a = a.data_shared_ptr() == nullptr;
   bool donate_b = b.data_shared_ptr() == nullptr;
-  set_array_buffer(compute_encoder, donate_a ? out : a, 0);
-  set_array_buffer(compute_encoder, donate_b ? out : b, 1);
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_input_array(donate_a ? out : a, 0);
+  compute_encoder.set_input_array(donate_b ? out : b, 1);
+  compute_encoder.set_output_array(out, 2);
 
   if (bopt == BinaryOpType::General) {
     auto ndim = shape.size();
@@ -253,12 +253,12 @@ void ternary_op(
   auto& s = out.primitive().stream();
   auto& d = metal::device(s.device);
   auto kernel = d.get_kernel(kname.str());
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, a, 0);
-  set_array_buffer(compute_encoder, b, 1);
-  set_array_buffer(compute_encoder, c, 2);
-  set_array_buffer(compute_encoder, out, 3);
+  compute_encoder.set_input_array(a, 0);
+  compute_encoder.set_input_array(b, 1);
+  compute_encoder.set_input_array(c, 2);
+  compute_encoder.set_output_array(out, 3);
 
   if (topt == TernaryOpType::General) {
     auto ndim = shape.size();
@@ -339,11 +339,11 @@ void unary_op(
   }
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
 
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(
-      compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
-  set_array_buffer(compute_encoder, out, 1);
+  compute_encoder.set_input_array(
+      in.data_shared_ptr() == nullptr ? out : in, 0);
+  compute_encoder.set_output_array(out, 1);
   if (!contig) {
     compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
     compute_encoder->setBytes(
@@ -365,7 +365,7 @@ void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
 }
 
 template <typename T>
-void arange_set_scalars(T start, T next, MTL::ComputeCommandEncoder* enc) {
+void arange_set_scalars(T start, T next, CommandEncoder& enc) {
   enc->setBytes(&start, sizeof(T), 0);
   T step = next - start;
   enc->setBytes(&step, sizeof(T), 1);
@@ -384,7 +384,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
   MTL::Size group_dims = MTL::Size(
       std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
 
   switch (out.dtype()) {
@@ -427,7 +427,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
       throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
   }
 
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_output_array(out, 2);
 
   compute_encoder->dispatchThreads(grid_dims, group_dims);
 }
@@ -487,7 +487,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   // ArgReduce
   int simd_size = 32;
   int n_reads = 4;
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   {
     auto kernel = d.get_kernel(op_name + type_to_name(in));
     NS::UInteger thread_group_size = std::min(
@@ -502,8 +502,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
     compute_encoder->setComputePipelineState(kernel);
-    set_array_buffer(compute_encoder, in, 0);
-    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(out, 1);
     if (ndim == 0) {
       // Pass place holders so metal doesn't complain
       int shape_ = 0;
@@ -552,6 +552,9 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
   flags.row_contiguous = false;
   flags.col_contiguous = false;
   flags.contiguous = false;
+  auto& d = metal::device(stream().device);
+  auto& compute_encoder = d.get_command_encoder(stream().index);
+  auto concurrent_ctx = compute_encoder.start_concurrent();
   for (int i = 0; i < inputs.size(); i++) {
     array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
     size_t data_offset = strides[axis_] * sizes[i];
@@ -791,10 +794,10 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
   MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, keys, 0);
-  set_array_buffer(compute_encoder, out, 1);
+  compute_encoder.set_input_array(keys, 0);
+  compute_encoder.set_output_array(out, 1);
   compute_encoder->setBytes(&odd, sizeof(bool), 2);
   compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3);
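Note: the `Concatenate::eval_gpu` hunk above is the first user of `start_concurrent`. While the returned `ConcurrentContext` is alive, written buffers go into `concurrent_outputs` instead of `outputs`, so the per-slice copy kernels are encoded without barriers between one another (each writes a disjoint slice of `out`); when the context is destroyed, those buffers are merged into `outputs`, so the next kernel that reads `out` still gets its barrier. A sketch of the pattern, with hypothetical names for the independent work items:

```cpp
{
  // While `ctx` lives, kernels encoded on this encoder may run
  // concurrently: set_output_array defers its barrier bookkeeping.
  auto ctx = compute_encoder.start_concurrent();
  for (auto& slice : independent_slices) {
    encode_copy_kernel(compute_encoder, slice);  // hypothetical helper
  }
}  // ~ConcurrentContext: the slices' buffers join `outputs`, so any
   // later kernel reading them synchronizes via a memory barrier.
```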
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index 58e42ed9b..b41ee68e2 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -48,7 +48,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
           << bits_ << "_fast";
 
     // Encode and dispatch kernel
-    auto compute_encoder = d.get_command_encoder(s.index);
+    auto& compute_encoder = d.get_command_encoder(s.index);
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);
 
@@ -57,11 +57,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     MTL::Size group_dims = MTL::Size(bd, 2, 1);
     MTL::Size grid_dims = MTL::Size(1, O / bo, B);
 
-    set_array_buffer(compute_encoder, w, 0);
-    set_array_buffer(compute_encoder, scales, 1);
-    set_array_buffer(compute_encoder, biases, 2);
-    set_array_buffer(compute_encoder, x, 3);
-    set_array_buffer(compute_encoder, out, 4);
+    compute_encoder.set_input_array(w, 0);
+    compute_encoder.set_input_array(scales, 1);
+    compute_encoder.set_input_array(biases, 2);
+    compute_encoder.set_input_array(x, 3);
+    compute_encoder.set_output_array(out, 4);
     compute_encoder->setBytes(&D, sizeof(int), 5);
     compute_encoder->setBytes(&O, sizeof(int), 6);
 
@@ -75,7 +75,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
           << bits_;
 
     // Encode and dispatch kernel
-    auto compute_encoder = d.get_command_encoder(s.index);
+    auto& compute_encoder = d.get_command_encoder(s.index);
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);
 
@@ -84,11 +84,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     MTL::Size group_dims = MTL::Size(bd, 2, 1);
     MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
 
-    set_array_buffer(compute_encoder, w, 0);
-    set_array_buffer(compute_encoder, scales, 1);
-    set_array_buffer(compute_encoder, biases, 2);
-    set_array_buffer(compute_encoder, x, 3);
-    set_array_buffer(compute_encoder, out, 4);
+    compute_encoder.set_input_array(w, 0);
+    compute_encoder.set_input_array(scales, 1);
+    compute_encoder.set_input_array(biases, 2);
+    compute_encoder.set_input_array(x, 3);
+    compute_encoder.set_output_array(out, 4);
     compute_encoder->setBytes(&D, sizeof(int), 5);
     compute_encoder->setBytes(&O, sizeof(int), 6);
 
@@ -102,7 +102,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
           << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
 
     // Encode and dispatch kernel
-    auto compute_encoder = d.get_command_encoder(s.index);
+    auto& compute_encoder = d.get_command_encoder(s.index);
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);
 
@@ -114,11 +114,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);
 
-    set_array_buffer(compute_encoder, x, 0);
-    set_array_buffer(compute_encoder, w, 1);
-    set_array_buffer(compute_encoder, scales, 2);
-    set_array_buffer(compute_encoder, biases, 3);
-    set_array_buffer(compute_encoder, out, 4);
+    compute_encoder.set_input_array(x, 0);
+    compute_encoder.set_input_array(w, 1);
+    compute_encoder.set_input_array(scales, 2);
+    compute_encoder.set_input_array(biases, 3);
+    compute_encoder.set_output_array(out, 4);
     compute_encoder->setBytes(&B, sizeof(int), 5);
     compute_encoder->setBytes(&O, sizeof(int), 6);
     compute_encoder->setBytes(&D, sizeof(int), 7);
@@ -133,7 +133,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
          << bits_;
 
     // Encode and dispatch kernel
-    auto compute_encoder = d.get_command_encoder(s.index);
+    auto& compute_encoder = d.get_command_encoder(s.index);
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);
 
@@ -142,11 +142,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     MTL::Size group_dims = MTL::Size(bd, bo, 1);
     MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
 
-    set_array_buffer(compute_encoder, x, 0);
-    set_array_buffer(compute_encoder, w, 1);
-    set_array_buffer(compute_encoder, scales, 2);
-    set_array_buffer(compute_encoder, biases, 3);
-    set_array_buffer(compute_encoder, out, 4);
+    compute_encoder.set_input_array(x, 0);
+    compute_encoder.set_input_array(w, 1);
+    compute_encoder.set_input_array(scales, 2);
+    compute_encoder.set_input_array(biases, 3);
+    compute_encoder.set_output_array(out, 4);
     compute_encoder->setBytes(&D, sizeof(int), 5);
     compute_encoder->setBytes(&O, sizeof(int), 6);
 
@@ -160,7 +160,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
          << bits_;
 
     // Encode and dispatch kernel
-    auto compute_encoder = d.get_command_encoder(s.index);
+    auto& compute_encoder = d.get_command_encoder(s.index);
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);
 
@@ -179,11 +179,11 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       throw std::runtime_error(msg.str());
     }
 
-    set_array_buffer(compute_encoder, x, 0);
-    set_array_buffer(compute_encoder, w, 1);
-    set_array_buffer(compute_encoder, scales, 2);
-    set_array_buffer(compute_encoder, biases, 3);
-    set_array_buffer(compute_encoder, out, 4);
+    compute_encoder.set_input_array(x, 0);
+    compute_encoder.set_input_array(w, 1);
+    compute_encoder.set_input_array(scales, 2);
+    compute_encoder.set_input_array(biases, 3);
+    compute_encoder.set_output_array(out, 4);
     compute_encoder->setBytes(&B, sizeof(int), 5);
     compute_encoder->setBytes(&O, sizeof(int), 6);
     compute_encoder->setBytes(&D, sizeof(int), 7);
diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp
index a9b425ab6..fc0d1ca1a 100644
--- a/mlx/backend/metal/reduce.cpp
+++ b/mlx/backend/metal/reduce.cpp
@@ -35,7 +35,7 @@ void all_reduce_dispatch(
     const array& in,
     array& out,
     const std::string& op_name,
-    MTL::ComputeCommandEncoder* compute_encoder,
+    CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
   Dtype out_dtype = out.dtype();
@@ -71,8 +71,8 @@ void all_reduce_dispatch(
 
   // Encode buffers and dispatch
   if (is_out_64b_int == false || n_thread_groups == 1) {
-    set_array_buffer(compute_encoder, in, 0);
-    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(out, 1);
     compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
     compute_encoder->dispatchThreads(grid_dims, group_dims);
 
@@ -85,14 +85,14 @@ void all_reduce_dispatch(
     std::vector<array> intermediates = {intermediate};
 
     // First dispatch
-    set_array_buffer(compute_encoder, in, 0);
-    set_array_buffer(compute_encoder, intermediate, 1);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(intermediate, 1);
     compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
     compute_encoder->dispatchThreads(grid_dims, group_dims);
 
     // Second pass to reduce intermediate reduction results written to DRAM
-    set_array_buffer(compute_encoder, intermediate, 0);
-    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder.set_input_array(intermediate, 0);
+    compute_encoder.set_output_array(out, 1);
     compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2);
 
     mod_in_size = (intermediate_size + n_reads - 1) / n_reads;
@@ -123,7 +123,7 @@ void row_reduce_general_dispatch(
     const std::string& op_name,
     const ReductionPlan& plan,
     const std::vector<int>& axes,
-    MTL::ComputeCommandEncoder* compute_encoder,
+    CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
   Dtype out_dtype = out.dtype();
@@ -208,8 +208,8 @@ void row_reduce_general_dispatch(
   // Dispatch kernel
   if (!is_out_64b_int || non_row_reductions == 1) {
     // Set the arguments for the kernel
-    set_array_buffer(compute_encoder, in, 0);
-    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(out, 1);
     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
     compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
     compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
@@ -230,8 +230,8 @@ void row_reduce_general_dispatch(
     std::vector<array> intermediates = {intermediate};
 
     // Set the arguments for the kernel
-    set_array_buffer(compute_encoder, in, 0);
-    set_array_buffer(compute_encoder, intermediate, 1);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(intermediate, 1);
     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
     compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
     compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
@@ -258,8 +258,8 @@ void row_reduce_general_dispatch(
     ndim = new_shape.size();
 
     // Set the arguments for the kernel
-    set_array_buffer(compute_encoder, intermediate, 0);
-    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder.set_input_array(intermediate, 0);
+    compute_encoder.set_output_array(out, 1);
     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
     compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
     compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
@@ -301,7 +301,7 @@ void strided_reduce_general_dispatch(
     const std::string& op_name,
     const ReductionPlan& plan,
     const std::vector<int>& axes,
-    MTL::ComputeCommandEncoder* compute_encoder,
+    CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s) {
   Dtype out_dtype = out.dtype();
@@ -349,8 +349,8 @@ void strided_reduce_general_dispatch(
   }
 
   // Encode arrays
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, out, 1);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(out, 1);
   compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
   compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
   compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
@@ -415,8 +415,8 @@ void strided_reduce_general_dispatch(
 
   if (is_out_64b_int == false) {
     // Set the arguments for the kernel
-    set_array_buffer(compute_encoder, in, 0);
-    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(out, 1);
     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
     compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
     compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
@@ -450,8 +450,8 @@ void strided_reduce_general_dispatch(
     std::vector<array> intermediates = {intermediate};
 
     // Set the arguments for the kernel
-    set_array_buffer(compute_encoder, in, 0);
-    set_array_buffer(compute_encoder, intermediate, 1);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(intermediate, 1);
     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
     compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
     compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
@@ -494,8 +494,8 @@ void strided_reduce_general_dispatch(
         "row_reduce_general_no_atomics_" + op_name +
         type_to_name(intermediate));
     compute_encoder->setComputePipelineState(row_reduce_kernel);
-    set_array_buffer(compute_encoder, intermediate, 0);
-    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder.set_input_array(intermediate, 0);
+    compute_encoder.set_output_array(out, 1);
     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
     compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
     compute_encoder->setBytes(&reduction_size, sizeof(size_t), 4);
@@ -573,7 +573,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Initialize output
     auto& s = stream();
     auto& d = metal::device(s.device);
-    auto compute_encoder = d.get_command_encoder(s.index);
+    auto& compute_encoder = d.get_command_encoder(s.index);
     {
       auto kernel = d.get_kernel("i" + op_name + type_to_name(out));
       size_t nthreads = out.size();
@@ -584,7 +584,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
      }
       MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
       compute_encoder->setComputePipelineState(kernel);
-      set_array_buffer(compute_encoder, out, 0);
+      compute_encoder.set_output_array(out, 0);
       compute_encoder->dispatchThreads(grid_dims, group_dims);
     }
diff --git a/mlx/backend/metal/reduce.h b/mlx/backend/metal/reduce.h
index b674d6ba4..a997d7e24 100644
--- a/mlx/backend/metal/reduce.h
+++ b/mlx/backend/metal/reduce.h
@@ -8,11 +8,13 @@
 
 namespace mlx::core {
 
+using metal::CommandEncoder;
+
 void all_reduce_dispatch(
     const array& in,
     array& out,
     const std::string& op_name,
-    MTL::ComputeCommandEncoder* compute_encoder,
+    CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s);
 
@@ -22,7 +24,7 @@ void row_reduce_general_dispatch(
     const std::string& op_name,
     const ReductionPlan& plan,
     const std::vector<int>& axes,
-    MTL::ComputeCommandEncoder* compute_encoder,
+    CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s);
 
@@ -32,7 +34,7 @@ void strided_reduce_general_dispatch(
     const std::string& op_name,
     const ReductionPlan& plan,
     const std::vector<int>& axes,
-    MTL::ComputeCommandEncoder* compute_encoder,
+    CommandEncoder& compute_encoder,
     metal::Device& d,
     const Stream& s);
diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp
index fd7e7d55b..1151f8c43 100644
--- a/mlx/backend/metal/rope.cpp
+++ b/mlx/backend/metal/rope.cpp
@@ -66,12 +66,12 @@ void RoPE::eval_gpu(
   kname << "rope_" << (forward_ ? "" : "vjp_")
         << (traditional_ ? "traditional_" : "") << type_to_name(in);
   auto kernel = d.get_kernel(kname.str());
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
 
   float base = std::log2(base_);
   compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, donated ? out : in, 0);
-  set_array_buffer(compute_encoder, out, 1);
+  compute_encoder.set_input_array(donated ? out : in, 0);
+  compute_encoder.set_output_array(out, 1);
   compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
   compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 3);
   compute_encoder->setBytes(&offset_, sizeof(int), 4);
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 55292f092..5b5a68870 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -71,7 +71,7 @@ void sdpa_metal(
   std::string kname_suffix = kname_suffix_tile_size + kname_suffix_nsimdgroups;
   kname_partials << kname_suffix;
 
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname_partials.str());
   compute_encoder->setComputePipelineState(kernel);
 
@@ -87,15 +87,15 @@ void sdpa_metal(
   MLXScaledDotProductAttentionParams params{
       query_sequence_length, n_q_heads, n_kv_heads, n_tiles, alpha};
 
-  set_array_buffer(compute_encoder, q, 0);
-  set_array_buffer(compute_encoder, k, 1);
-  set_array_buffer(compute_encoder, v, 2);
+  compute_encoder.set_input_array(q, 0);
+  compute_encoder.set_input_array(k, 1);
+  compute_encoder.set_input_array(v, 2);
   compute_encoder->setBytes(&KV_sequence_length, sizeof(KV_sequence_length), 3);
   compute_encoder->setBytes(
       &params, sizeof(MLXScaledDotProductAttentionParams), 4);
-  set_array_buffer(compute_encoder, o_partial, 5);
-  set_array_buffer(compute_encoder, p_lse, 6);
-  set_array_buffer(compute_encoder, p_rowmaxes, 7);
+  compute_encoder.set_input_array(o_partial, 5);
+  compute_encoder.set_input_array(p_lse, 6);
+  compute_encoder.set_input_array(p_rowmaxes, 7);
 
   constexpr const uint tgroupMemorySize = 32768;
   compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
@@ -104,12 +104,12 @@ void sdpa_metal(
   {
     auto kernel_accum = d.get_kernel(kname_reduce.str());
     compute_encoder->setComputePipelineState(kernel_accum);
-    set_array_buffer(compute_encoder, o_partial, 0);
-    set_array_buffer(compute_encoder, p_lse, 1);
-    set_array_buffer(compute_encoder, p_rowmaxes, 2);
+    compute_encoder.set_input_array(o_partial, 0);
+    compute_encoder.set_input_array(p_lse, 1);
+    compute_encoder.set_input_array(p_rowmaxes, 2);
     compute_encoder->setBytes(
         &params, sizeof(MLXScaledDotProductAttentionParams), 3);
-    set_array_buffer(compute_encoder, out, 4);
+    compute_encoder.set_output_array(out, 4);
 
     MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
     MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp
index 92f1007e1..94757b1e7 100644
--- a/mlx/backend/metal/scan.cpp
+++ b/mlx/backend/metal/scan.cpp
@@ -52,10 +52,10 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     kname << type_to_name(in) << "_" << type_to_name(out);
     auto kernel = d.get_kernel(kname.str());
-    auto compute_encoder = d.get_command_encoder(s.index);
+    auto& compute_encoder = d.get_command_encoder(s.index);
     compute_encoder->setComputePipelineState(kernel);
-    set_array_buffer(compute_encoder, in, 0);
-    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(out, 1);
     size_t size = in.shape(axis_);
     compute_encoder->setBytes(&size, sizeof(size_t), 2);
@@ -101,10 +101,10 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     kname << type_to_name(in) << "_" << type_to_name(out);
     auto kernel = d.get_kernel(kname.str());
-    auto compute_encoder = d.get_command_encoder(s.index);
+    auto& compute_encoder = d.get_command_encoder(s.index);
     compute_encoder->setComputePipelineState(kernel);
-    set_array_buffer(compute_encoder, in, 0);
-    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(out, 1);
     size_t size = in.shape(axis_);
     size_t stride = in.strides()[axis_];
     compute_encoder->setBytes(&size, sizeof(size_t), 2);
diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp
index 58b19141c..1fbc1e00c 100644
--- a/mlx/backend/metal/softmax.cpp
+++ b/mlx/backend/metal/softmax.cpp
@@ -60,7 +60,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
     op_name += "precise_";
   }
   op_name += type_to_name(out);
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   {
     auto kernel = d.get_kernel(op_name);
 
@@ -81,9 +81,9 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
 
     compute_encoder->setComputePipelineState(kernel);
-    set_array_buffer(
-        compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
-    set_array_buffer(compute_encoder, out, 1);
+    compute_encoder.set_input_array(
+        in.data_shared_ptr() == nullptr ? out : in, 0);
+    compute_encoder.set_output_array(out, 1);
     compute_encoder->setBytes(&axis_size, sizeof(int), 2);
     compute_encoder->dispatchThreads(grid_dims, group_dims);
   }
diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp
index ad8c83e48..9f53779a0 100644
--- a/mlx/backend/metal/sort.cpp
+++ b/mlx/backend/metal/sort.cpp
@@ -57,13 +57,13 @@ void single_block_sort(
   }
 
   // Prepare command encoder
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
   // Set inputs
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, out, 1);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(out, 1);
   compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
   compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);
 
@@ -131,7 +131,7 @@ void multi_block_sort(
       dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
 
   // Prepare command encoder
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
 
   // Do blockwise sort
   {
@@ -142,9 +142,9 @@ void multi_block_sort(
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);
 
-    set_array_buffer(compute_encoder, in, 0);
-    set_array_buffer(compute_encoder, dev_vals_0, 1);
-    set_array_buffer(compute_encoder, dev_idxs_0, 2);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(dev_vals_0, 1);
+    compute_encoder.set_output_array(dev_idxs_0, 2);
     compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
     compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
     compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
@@ -181,9 +181,9 @@ void multi_block_sort(
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);
 
-    set_array_buffer(compute_encoder, block_partitions, 0);
-    set_array_buffer(compute_encoder, dev_vals_in, 1);
-    set_array_buffer(compute_encoder, dev_idxs_in, 2);
+    compute_encoder.set_output_array(block_partitions, 0);
+    compute_encoder.set_input_array(dev_vals_in, 1);
+    compute_encoder.set_input_array(dev_idxs_in, 2);
     compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
     compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
 
@@ -202,11 +202,11 @@ void multi_block_sort(
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);
 
-    set_array_buffer(compute_encoder, block_partitions, 0);
-    set_array_buffer(compute_encoder, dev_vals_in, 1);
-    set_array_buffer(compute_encoder, dev_idxs_in, 2);
-    set_array_buffer(compute_encoder, dev_vals_out, 3);
-    set_array_buffer(compute_encoder, dev_idxs_out, 4);
+    compute_encoder.set_input_array(block_partitions, 0);
+    compute_encoder.set_input_array(dev_vals_in, 1);
+    compute_encoder.set_input_array(dev_idxs_in, 2);
+    compute_encoder.set_output_array(dev_vals_out, 3);
+    compute_encoder.set_output_array(dev_idxs_out, 4);
     compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
     compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
     compute_encoder->setBytes(&n_blocks, sizeof(int), 7);
diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h
index 1b90aa6c8..0ec315dd5 100644
--- a/mlx/backend/metal/utils.h
+++ b/mlx/backend/metal/utils.h
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #pragma once
 
@@ -10,29 +10,11 @@ namespace mlx::core {
 
 namespace {
 
-inline void
-set_array_buffer(MTL::ComputeCommandEncoder* enc, const array& a, int idx) {
-  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
-  auto offset = a.data<char>() -
-      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
-  enc->setBuffer(a_buf, offset, idx);
-}
-
-inline void set_array_buffer(
-    MTL::ComputeCommandEncoder* enc,
-    const array& a,
-    int64_t offset,
-    int idx) {
-  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
-  auto base_offset = a.data<char>() -
-      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
-  base_offset += offset;
-  enc->setBuffer(a_buf, base_offset, idx);
-}
+using metal::CommandEncoder;
 
 template <typename T>
 inline void set_vector_bytes(
-    MTL::ComputeCommandEncoder* enc,
+    CommandEncoder& enc,
     const std::vector<T>& vec,
     size_t nelems,
     int idx) {
@@ -40,10 +22,8 @@ inline void set_vector_bytes(
 }
 
 template <typename T>
-inline void set_vector_bytes(
-    MTL::ComputeCommandEncoder* enc,
-    const std::vector<T>& vec,
-    int idx) {
+inline void
+set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
   return set_vector_bytes(enc, vec, vec.size(), idx);
 }
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 004421e96..ceb6b291d 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -1050,19 +1050,30 @@ array transpose(
   for (auto& ax : axes) {
     ax = ax < 0 ? ax + a.ndim() : ax;
   }
-  std::set<int> dims(axes.begin(), axes.end());
-  if (dims.size() != axes.size()) {
-    throw std::invalid_argument("Repeat axes not allowed in transpose.");
+  if (axes.size() != a.ndim()) {
+    std::ostringstream msg;
+    msg << "[transpose] Received " << axes.size() << " axes for array with "
+        << a.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
   }
-  if (dims.size() != a.ndim() ||
-      a.ndim() > 0 &&
-          (*dims.begin() != 0 || *dims.rbegin() != (a.ndim() - 1))) {
-    throw std::invalid_argument("Transpose axes don't match array dimensions.");
+
+  // Check in bounds and for duplicates
+  std::vector<int> shape(axes.size(), 0);
+  for (auto& ax : axes) {
+    if (ax < 0 || ax >= a.ndim()) {
+      std::ostringstream msg;
+      msg << "[transpose] Invalid axis (" << ax << ") for array with "
+          << a.ndim() << " dimensions.";
+      throw std::invalid_argument(msg.str());
+    }
+    if (shape[ax] != 0) {
+      throw std::invalid_argument("[transpose] Repeat axes not allowed.");
+    }
+    shape[ax] = 1;
   }
-  std::vector<int> shape;
-  shape.reserve(axes.size());
-  for (auto ax : axes) {
-    shape.push_back(a.shape()[ax]);
+
+  for (int i = 0; i < axes.size(); ++i) {
+    shape[i] = a.shape()[axes[i]];
   }
   return array(
       std::move(shape),
diff --git a/python/src/utils.h b/python/src/utils.h
index 35ac53a52..bd29df2d6 100644
--- a/python/src/utils.h
+++ b/python/src/utils.h
@@ -33,10 +33,10 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
 }
 
 inline array to_array_with_accessor(nb::object obj) {
-  if (nb::hasattr(obj, "__mlx_array__")) {
-    return nb::cast<array>(obj.attr("__mlx_array__")());
-  } else if (nb::isinstance<array>(obj)) {
+  if (nb::isinstance<array>(obj)) {
     return nb::cast<array>(obj);
+  } else if (nb::hasattr(obj, "__mlx_array__")) {
+    return nb::cast<array>(obj.attr("__mlx_array__")());
   } else {
     std::ostringstream msg;
     msg << "Invalid type " << nb::type_name(obj.type()).c_str()
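Note: the `transpose` rework above folds the bounds check, the duplicate check, and the shape permutation into a single O(ndim) pass that reuses `shape` as a seen-marker array. A standalone sketch of the same validation logic, as a hypothetical free function that is not part of the patch:

```cpp
#include <stdexcept>
#include <vector>

// Sketch of the axis validation introduced in transpose(): one pass
// checks bounds and duplicates, using `shape` as a seen-marker before
// it is overwritten with the permuted dimensions.
std::vector<int> permuted_shape(
    const std::vector<int>& in_shape,
    std::vector<int> axes) {
  int ndim = in_shape.size();
  for (auto& ax : axes) {
    ax = ax < 0 ? ax + ndim : ax; // wrap negative axes
  }
  if (static_cast<int>(axes.size()) != ndim) {
    throw std::invalid_argument("wrong number of axes");
  }
  std::vector<int> shape(axes.size(), 0);
  for (auto ax : axes) {
    if (ax < 0 || ax >= ndim) {
      throw std::invalid_argument("axis out of bounds");
    }
    if (shape[ax] != 0) {
      throw std::invalid_argument("repeated axis"); // marker already set
    }
    shape[ax] = 1; // mark axis as seen
  }
  // Overwrite the markers with the permuted dimensions.
  for (int i = 0; i < ndim; ++i) {
    shape[i] = in_shape[axes[i]];
  }
  return shape;
}
```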