From 5746c0c65865cee399aed9aca3b0fe305b1d91e5 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sun, 24 Aug 2025 09:13:45 +0900 Subject: [PATCH] Move DynamicSlice to gpu/primitives --- mlx/backend/gpu/primitives.cpp | 66 +++++++++++++++++ mlx/backend/gpu/slicing.h | 6 ++ mlx/backend/metal/primitives.cpp | 121 ------------------------------- mlx/backend/metal/slicing.cpp | 57 +++++++++++++++ 4 files changed, 129 insertions(+), 121 deletions(-) diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index 6017879a5..98e88ef04 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -80,6 +80,72 @@ void Depends::eval_gpu( eval(inputs, outputs); } +void DynamicSlice::eval_gpu(const std::vector& inputs, array& out) { + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + auto& start = inputs[1]; + out.set_data(allocator::malloc(out.nbytes())); + + auto s = stream(); + auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s); + copy_gpu_inplace( + /* const array& src = */ in, + /* array& dst = */ out, + /* const Shape& data_shape = */ out.shape(), + /* const Strides& i_strides = */ in.strides(), + /* const Strides& o_strides = */ out.strides(), + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ 0, + /* CopyType ctype = */ CopyType::GeneralGeneral, + /* const Stream& s = */ s, + /* const std::optional& dynamic_i_offset = */ in_offset, + /* const std::optional& dynamic_o_offset = */ std::nullopt); +} + +void DynamicSliceUpdate::eval_gpu( + const std::vector& inputs, + array& out) { + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + auto& upd = inputs[1]; + auto& start_indices = inputs[2]; + + if (upd.size() == 0) { + out.copy_shared_buffer(in); + return; + } + + // Copy or donate input to output + auto s = stream(); + auto ctype = in.flags().contiguous && in.size() == in.data_size() + ? CopyType::Vector + : CopyType::General; + copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s); + + auto out_offset = + compute_dynamic_offset(start_indices, out.strides(), axes_, s); + copy_gpu_inplace( + /* const array& src = */ upd, + /* array& dst = */ out, + /* const Shape& data_shape = */ upd.shape(), + /* const Strides& i_strides = */ upd.strides(), + /* const Strides& o_strides = */ out.strides(), + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ 0, + /* CopyType ctype = */ CopyType::GeneralGeneral, + /* const Stream& s = */ s, + /* const std::optional& dynamic_i_offset = */ std::nullopt, + /* const std::optional& dynamic_o_offset = */ out_offset); +} + void ExpandDims::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("ExpandDims::eval_gpu"); eval(inputs, out); diff --git a/mlx/backend/gpu/slicing.h b/mlx/backend/gpu/slicing.h index 7c48214a4..596f7afa5 100644 --- a/mlx/backend/gpu/slicing.h +++ b/mlx/backend/gpu/slicing.h @@ -27,4 +27,10 @@ void pad_gpu( const Shape& low_pad_size, const Stream& s); +array compute_dynamic_offset( + const array& indices, + const Strides& strides, + const std::vector& axes, + const Stream& s); + } // namespace mlx::core diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 2ac543ad8..92a9a4158 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -4,7 +4,6 @@ #include #include -#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" @@ -25,60 +24,6 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) { enc.set_bytes(step, 1); } -static array compute_dynamic_offset( - const array& indices, - const Strides& strides, - const std::vector& axes, - Stream s) { - auto& d = metal::device(s.device); - - // Kernel to compute offset here. - array offset({1}, int64, nullptr, {}); - bool donate = indices.is_donatable() && - (indices.data_size() * indices.itemsize()) >= offset.itemsize(); - if (donate) { - offset.copy_shared_buffer(indices); - } else { - offset.set_data(allocator::malloc(offset.itemsize())); - } - d.add_temporary(offset, s.index); - - auto dtype = indices.dtype(); - std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype); - auto lib = d.get_library(lib_name, [dtype]() { - return fmt::format( - R"( - [[kernel]] void compute_dynamic_offset_{0}( - constant const {1}* indices [[buffer(0)]], - device int64_t& offset [[buffer(1)]], - constant const int64_t* strides [[buffer(2)]], - constant const int* axes [[buffer(3)]], - constant const int& n_axes [[buffer(4)]], - uint index [[thread_position_in_grid]]) {{ - int64_t acc = 0; - for (int i = 0; i < n_axes; ++i) {{ - acc += indices[i] * strides[axes[i]]; - }} - offset = acc; - }})", - type_to_name(dtype), - get_type_string(dtype)); - }); - auto kernel = d.get_kernel(lib_name, lib); - - auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(indices, 0); - compute_encoder.set_output_array(offset, 1); - compute_encoder.set_vector_bytes(strides, 2); - compute_encoder.set_vector_bytes(axes, 3); - int n_axes = axes.size(); - compute_encoder.set_bytes(n_axes, 4); - MTL::Size dims = MTL::Size(1, 1, 1); - compute_encoder.dispatch_threads(dims, dims); - return offset; -} - void Arange::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 0); out.set_data(allocator::malloc(out.nbytes())); @@ -256,72 +201,6 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatch_threads(grid_dims, group_dims); } -void DynamicSlice::eval_gpu(const std::vector& inputs, array& out) { - if (out.size() == 0) { - out.set_data(nullptr); - return; - } - - auto& in = inputs[0]; - auto& start = inputs[1]; - out.set_data(allocator::malloc(out.nbytes())); - auto s = stream(); - auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s); - copy_gpu_inplace( - /* const array& src = */ in, - /* array& dst = */ out, - /* const Shape& data_shape = */ out.shape(), - /* const Strides& i_strides = */ in.strides(), - /* const Strides& o_strides = */ out.strides(), - /* int64_t i_offset = */ 0, - /* int64_t o_offset = */ 0, - /* CopyType ctype = */ CopyType::GeneralGeneral, - /* const Stream& s = */ s, - /* const std::optional& dynamic_i_offset = */ in_offset, - /* const std::optional& dynamic_o_offset = */ std::nullopt); -} - -void DynamicSliceUpdate::eval_gpu( - const std::vector& inputs, - array& out) { - if (out.size() == 0) { - out.set_data(nullptr); - return; - } - - auto& in = inputs[0]; - auto& upd = inputs[1]; - auto& start_indices = inputs[2]; - - if (upd.size() == 0) { - out.copy_shared_buffer(in); - return; - } - - // Copy or donate input to output - auto s = stream(); - auto& d = metal::device(s.device); - auto ctype = in.flags().contiguous && in.size() == in.data_size() - ? CopyType::Vector - : CopyType::General; - copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s); - - auto out_offset = - compute_dynamic_offset(start_indices, out.strides(), axes_, s); - copy_gpu_inplace( - /* const array& src = */ upd, - /* array& dst = */ out, - /* const Shape& data_shape = */ upd.shape(), - /* const Strides& i_strides = */ upd.strides(), - /* const Strides& o_strides = */ out.strides(), - /* int64_t i_offset = */ 0, - /* int64_t o_offset = */ 0, - /* CopyType ctype = */ CopyType::GeneralGeneral, - /* const Stream& s = */ s, - /* const std::optional& dynamic_i_offset = */ std::nullopt, - /* const std::optional& dynamic_o_offset = */ out_offset); -} - void QRF::eval_gpu( const std::vector& inputs, std::vector& outputs) { diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp index 3e1a8b541..1e14c35c8 100644 --- a/mlx/backend/metal/slicing.cpp +++ b/mlx/backend/metal/slicing.cpp @@ -2,9 +2,12 @@ #include +#include "mlx/backend/common/compiled.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels.h" +#include "mlx/backend/metal/utils.h" namespace mlx::core { @@ -39,4 +42,58 @@ void concatenate_gpu( } } +array compute_dynamic_offset( + const array& indices, + const Strides& strides, + const std::vector& axes, + const Stream& s) { + auto& d = metal::device(s.device); + + // Kernel to compute offset here. + array offset({1}, int64, nullptr, {}); + bool donate = indices.is_donatable() && + (indices.data_size() * indices.itemsize()) >= offset.itemsize(); + if (donate) { + offset.copy_shared_buffer(indices); + } else { + offset.set_data(allocator::malloc(offset.itemsize())); + } + d.add_temporary(offset, s.index); + + auto dtype = indices.dtype(); + std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype); + auto lib = d.get_library(lib_name, [dtype]() { + return fmt::format( + R"( + [[kernel]] void compute_dynamic_offset_{0}( + constant const {1}* indices [[buffer(0)]], + device int64_t& offset [[buffer(1)]], + constant const int64_t* strides [[buffer(2)]], + constant const int* axes [[buffer(3)]], + constant const int& n_axes [[buffer(4)]], + uint index [[thread_position_in_grid]]) {{ + int64_t acc = 0; + for (int i = 0; i < n_axes; ++i) {{ + acc += indices[i] * strides[axes[i]]; + }} + offset = acc; + }})", + type_to_name(dtype), + get_type_string(dtype)); + }); + auto kernel = d.get_kernel(lib_name, lib); + + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(indices, 0); + compute_encoder.set_output_array(offset, 1); + compute_encoder.set_vector_bytes(strides, 2); + compute_encoder.set_vector_bytes(axes, 3); + int n_axes = axes.size(); + compute_encoder.set_bytes(n_axes, 4); + MTL::Size dims = MTL::Size(1, 1, 1); + compute_encoder.dispatch_threads(dims, dims); + return offset; +} + } // namespace mlx::core