diff --git a/mlx/backend/cuda/copy.cu b/mlx/backend/cuda/copy.cu index 321806720..158d3de6e 100644 --- a/mlx/backend/cuda/copy.cu +++ b/mlx/backend/cuda/copy.cu @@ -15,8 +15,8 @@ void copy_gpu_inplace( int64_t offset_out, CopyType ctype, const Stream& s, - const std::optional& dynamic_offset_in, - const std::optional& dynamic_offset_out) { + std::optional dynamic_offset_in, + std::optional dynamic_offset_out) { if (out.size() == 0) { return; } @@ -44,6 +44,16 @@ void copy_gpu_inplace( strides_vec[0]); } else { if (dynamic_offset_in || dynamic_offset_out) { + if (!dynamic_offset_in) { + dynamic_offset_in = array(0, int64); + encoder.add_temporary(*dynamic_offset_in); + } + if (!dynamic_offset_out) { + dynamic_offset_out = array(0, int64); + encoder.add_temporary(*dynamic_offset_out); + } + encoder.set_input_array(*dynamic_offset_in); + encoder.set_input_array(*dynamic_offset_out); copy_general_dynamic( encoder, ctype, @@ -54,8 +64,8 @@ void copy_gpu_inplace( shape_collapsed, strides_vec[0], strides_vec[1], - dynamic_offset_in ? *dynamic_offset_in : array(0, int64), - dynamic_offset_out ? *dynamic_offset_out : array(0, int64)); + *dynamic_offset_in, + *dynamic_offset_out); } else { copy_general( encoder, diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 829529609..1d867e063 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { args.append(src.ndim()); args.append_ndim(slice_sizes_); args.append(slice_size); - args.append(SmallVector(axes_.begin(), axes_.end())); + args.append(axes_); append_indices_arg(args, inputs, nidx, idx_ndim); std::string kernel_name = fmt::format( @@ -211,7 +211,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { args.append_ndim(out.shape()); args.append_ndim(out.strides()); args.append(out.ndim()); - args.append(SmallVector(axes_.begin(), axes_.end())); + args.append(axes_); append_indices_arg(args, inputs, nidx, idx_ndim); std::string kernel_name = fmt::format( diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h index d919f9bc0..cc569690a 100644 --- a/mlx/backend/cuda/jit_module.h +++ b/mlx/backend/cuda/jit_module.h @@ -46,6 +46,11 @@ struct KernelArgs { append_ptr(std::get>(storage_.back()).data()); } + template + void append(const std::vector& vec) { + append(SmallVector(vec.begin(), vec.end())); + } + // Make sure the arg is copied to an array with size of NDIM. template void append_ndim(SmallVector vec) { diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp index f9a594ab8..77c295665 100644 --- a/mlx/backend/cuda/primitives.cpp +++ b/mlx/backend/cuda/primitives.cpp @@ -24,8 +24,6 @@ namespace mlx::core { } NO_GPU(BlockMaskedMM) -NO_GPU(DynamicSlice) -NO_GPU(DynamicSliceUpdate) NO_GPU(FFT) NO_GPU(GatherMM) NO_GPU(GatherQMM) diff --git a/mlx/backend/cuda/slicing.cpp b/mlx/backend/cuda/slicing.cpp index af67fbbdd..18cc14bbd 100644 --- a/mlx/backend/cuda/slicing.cpp +++ b/mlx/backend/cuda/slicing.cpp @@ -1,8 +1,11 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/common/slicing.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/slicing.h" +#include "mlx/dtype_utils.h" #include @@ -38,4 +41,71 @@ void concatenate_gpu( } } +array compute_dynamic_offset( + const array& indices, + const Strides& strides, + const std::vector& axes, + const Stream& s) { + Dtype dtype = indices.dtype(); + int nidx = axes.size(); + + std::string module_name = + fmt::format("compute_dynamic_offset_{}_{}", dtype_to_string(dtype), nidx); + std::string kernel_name = fmt::format( + "mlx::core::cu::compute_dynamic_offset<{}, {}>", + dtype_to_cuda_type(dtype), + nidx); + + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::string source = R"( + #include "mlx/backend/cuda/device/utils.cuh" + + namespace mlx::core::cu { + + template + __global__ void compute_dynamic_offset( + const T* indices, + int64_t* offset, + const __grid_constant__ Strides strides, + const __grid_constant__ cuda::std::array axes) { + int64_t acc = 0; + #pragma unroll + for (int i = 0; i < NIDX; ++i) { + acc += indices[i] * strides[axes[i]]; + } + *offset = acc; + } + + } // namespace mlx::core::cu + )"; + return std::make_tuple(false, std::move(source), std::vector{kernel_name}); + }); + + // Prepare output. + array offset({1}, int64, nullptr, {}); + bool donate = indices.is_donatable() && + (indices.data_size() * indices.itemsize()) >= offset.itemsize(); + if (donate) { + offset.copy_shared_buffer(indices); + } else { + offset.set_data(allocator::malloc(offset.itemsize())); + } + + auto& encoder = cu::get_command_encoder(s); + encoder.add_temporary(offset); + encoder.set_input_array(indices); + encoder.set_output_array(offset); + + cu::KernelArgs args; + args.append(indices); + args.append(offset); + args.append_ndim(strides); + args.append(axes); + + auto kernel = mod.get_kernel(kernel_name); + encoder.add_kernel_node(kernel, 1, 1, 0, args.args()); + + return offset; +} + } // namespace mlx::core diff --git a/mlx/backend/gpu/copy.h b/mlx/backend/gpu/copy.h index 274250202..6e6bc7978 100644 --- a/mlx/backend/gpu/copy.h +++ b/mlx/backend/gpu/copy.h @@ -20,8 +20,8 @@ void copy_gpu_inplace( int64_t o_offset, CopyType ctype, const Stream& s, - const std::optional& dynamic_i_offset = std::nullopt, - const std::optional& dynamic_o_offset = std::nullopt); + std::optional dynamic_i_offset = std::nullopt, + std::optional dynamic_o_offset = std::nullopt); void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s); void copy_gpu(const array& src, array& out, CopyType ctype); diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index 6017879a5..ee40799df 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -80,6 +80,74 @@ void Depends::eval_gpu( eval(inputs, outputs); } +void DynamicSlice::eval_gpu(const std::vector& inputs, array& out) { + MLX_PROFILER_RANGE("DynamicSlice::eval_gpu"); + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + auto& start = inputs[1]; + out.set_data(allocator::malloc(out.nbytes())); + + auto s = stream(); + auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s); + copy_gpu_inplace( + /* const array& src = */ in, + /* array& dst = */ out, + /* const Shape& data_shape = */ out.shape(), + /* const Strides& i_strides = */ in.strides(), + /* const Strides& o_strides = */ out.strides(), + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ 0, + /* CopyType ctype = */ CopyType::GeneralGeneral, + /* const Stream& s = */ s, + /* std::optional dynamic_i_offset = */ std::move(in_offset), + /* std::optional dynamic_o_offset = */ std::nullopt); +} + +void DynamicSliceUpdate::eval_gpu( + const std::vector& inputs, + array& out) { + MLX_PROFILER_RANGE("DynamicSliceUpdate::eval_gpu"); + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + auto& upd = inputs[1]; + auto& start_indices = inputs[2]; + + if (upd.size() == 0) { + out.copy_shared_buffer(in); + return; + } + + // Copy or donate input to output + auto s = stream(); + auto ctype = in.flags().contiguous && in.size() == in.data_size() + ? CopyType::Vector + : CopyType::General; + copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s); + + auto out_offset = + compute_dynamic_offset(start_indices, out.strides(), axes_, s); + copy_gpu_inplace( + /* const array& src = */ upd, + /* array& dst = */ out, + /* const Shape& data_shape = */ upd.shape(), + /* const Strides& i_strides = */ upd.strides(), + /* const Strides& o_strides = */ out.strides(), + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ 0, + /* CopyType ctype = */ CopyType::GeneralGeneral, + /* const Stream& s = */ s, + /* std::optional dynamic_i_offset = */ std::nullopt, + /* std::optional dynamic_o_offset = */ std::move(out_offset)); +} + void ExpandDims::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("ExpandDims::eval_gpu"); eval(inputs, out); diff --git a/mlx/backend/gpu/slicing.h b/mlx/backend/gpu/slicing.h index 7c48214a4..596f7afa5 100644 --- a/mlx/backend/gpu/slicing.h +++ b/mlx/backend/gpu/slicing.h @@ -27,4 +27,10 @@ void pad_gpu( const Shape& low_pad_size, const Stream& s); +array compute_dynamic_offset( + const array& indices, + const Strides& strides, + const std::vector& axes, + const Stream& s); + } // namespace mlx::core diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 915fc69fd..d9c568e52 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -20,8 +20,8 @@ void copy_gpu_inplace( int64_t out_offset, CopyType ctype, const Stream& s, - const std::optional& dynamic_i_offset /* = std::nullopt */, - const std::optional& dynamic_o_offset /* = std::nullopt */) { + std::optional dynamic_i_offset /* = std::nullopt */, + std::optional dynamic_o_offset /* = std::nullopt */) { if (out.size() == 0) { return; } diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 2ac543ad8..92a9a4158 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -4,7 +4,6 @@ #include #include -#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" @@ -25,60 +24,6 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) { enc.set_bytes(step, 1); } -static array compute_dynamic_offset( - const array& indices, - const Strides& strides, - const std::vector& axes, - Stream s) { - auto& d = metal::device(s.device); - - // Kernel to compute offset here. - array offset({1}, int64, nullptr, {}); - bool donate = indices.is_donatable() && - (indices.data_size() * indices.itemsize()) >= offset.itemsize(); - if (donate) { - offset.copy_shared_buffer(indices); - } else { - offset.set_data(allocator::malloc(offset.itemsize())); - } - d.add_temporary(offset, s.index); - - auto dtype = indices.dtype(); - std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype); - auto lib = d.get_library(lib_name, [dtype]() { - return fmt::format( - R"( - [[kernel]] void compute_dynamic_offset_{0}( - constant const {1}* indices [[buffer(0)]], - device int64_t& offset [[buffer(1)]], - constant const int64_t* strides [[buffer(2)]], - constant const int* axes [[buffer(3)]], - constant const int& n_axes [[buffer(4)]], - uint index [[thread_position_in_grid]]) {{ - int64_t acc = 0; - for (int i = 0; i < n_axes; ++i) {{ - acc += indices[i] * strides[axes[i]]; - }} - offset = acc; - }})", - type_to_name(dtype), - get_type_string(dtype)); - }); - auto kernel = d.get_kernel(lib_name, lib); - - auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(indices, 0); - compute_encoder.set_output_array(offset, 1); - compute_encoder.set_vector_bytes(strides, 2); - compute_encoder.set_vector_bytes(axes, 3); - int n_axes = axes.size(); - compute_encoder.set_bytes(n_axes, 4); - MTL::Size dims = MTL::Size(1, 1, 1); - compute_encoder.dispatch_threads(dims, dims); - return offset; -} - void Arange::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 0); out.set_data(allocator::malloc(out.nbytes())); @@ -256,72 +201,6 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatch_threads(grid_dims, group_dims); } -void DynamicSlice::eval_gpu(const std::vector& inputs, array& out) { - if (out.size() == 0) { - out.set_data(nullptr); - return; - } - - auto& in = inputs[0]; - auto& start = inputs[1]; - out.set_data(allocator::malloc(out.nbytes())); - auto s = stream(); - auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s); - copy_gpu_inplace( - /* const array& src = */ in, - /* array& dst = */ out, - /* const Shape& data_shape = */ out.shape(), - /* const Strides& i_strides = */ in.strides(), - /* const Strides& o_strides = */ out.strides(), - /* int64_t i_offset = */ 0, - /* int64_t o_offset = */ 0, - /* CopyType ctype = */ CopyType::GeneralGeneral, - /* const Stream& s = */ s, - /* const std::optional& dynamic_i_offset = */ in_offset, - /* const std::optional& dynamic_o_offset = */ std::nullopt); -} - -void DynamicSliceUpdate::eval_gpu( - const std::vector& inputs, - array& out) { - if (out.size() == 0) { - out.set_data(nullptr); - return; - } - - auto& in = inputs[0]; - auto& upd = inputs[1]; - auto& start_indices = inputs[2]; - - if (upd.size() == 0) { - out.copy_shared_buffer(in); - return; - } - - // Copy or donate input to output - auto s = stream(); - auto& d = metal::device(s.device); - auto ctype = in.flags().contiguous && in.size() == in.data_size() - ? CopyType::Vector - : CopyType::General; - copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s); - - auto out_offset = - compute_dynamic_offset(start_indices, out.strides(), axes_, s); - copy_gpu_inplace( - /* const array& src = */ upd, - /* array& dst = */ out, - /* const Shape& data_shape = */ upd.shape(), - /* const Strides& i_strides = */ upd.strides(), - /* const Strides& o_strides = */ out.strides(), - /* int64_t i_offset = */ 0, - /* int64_t o_offset = */ 0, - /* CopyType ctype = */ CopyType::GeneralGeneral, - /* const Stream& s = */ s, - /* const std::optional& dynamic_i_offset = */ std::nullopt, - /* const std::optional& dynamic_o_offset = */ out_offset); -} - void QRF::eval_gpu( const std::vector& inputs, std::vector& outputs) { diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp index 3e1a8b541..1e14c35c8 100644 --- a/mlx/backend/metal/slicing.cpp +++ b/mlx/backend/metal/slicing.cpp @@ -2,9 +2,12 @@ #include +#include "mlx/backend/common/compiled.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/slicing.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels.h" +#include "mlx/backend/metal/utils.h" namespace mlx::core { @@ -39,4 +42,58 @@ void concatenate_gpu( } } +array compute_dynamic_offset( + const array& indices, + const Strides& strides, + const std::vector& axes, + const Stream& s) { + auto& d = metal::device(s.device); + + // Kernel to compute offset here. + array offset({1}, int64, nullptr, {}); + bool donate = indices.is_donatable() && + (indices.data_size() * indices.itemsize()) >= offset.itemsize(); + if (donate) { + offset.copy_shared_buffer(indices); + } else { + offset.set_data(allocator::malloc(offset.itemsize())); + } + d.add_temporary(offset, s.index); + + auto dtype = indices.dtype(); + std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype); + auto lib = d.get_library(lib_name, [dtype]() { + return fmt::format( + R"( + [[kernel]] void compute_dynamic_offset_{0}( + constant const {1}* indices [[buffer(0)]], + device int64_t& offset [[buffer(1)]], + constant const int64_t* strides [[buffer(2)]], + constant const int* axes [[buffer(3)]], + constant const int& n_axes [[buffer(4)]], + uint index [[thread_position_in_grid]]) {{ + int64_t acc = 0; + for (int i = 0; i < n_axes; ++i) {{ + acc += indices[i] * strides[axes[i]]; + }} + offset = acc; + }})", + type_to_name(dtype), + get_type_string(dtype)); + }); + auto kernel = d.get_kernel(lib_name, lib); + + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(indices, 0); + compute_encoder.set_output_array(offset, 1); + compute_encoder.set_vector_bytes(strides, 2); + compute_encoder.set_vector_bytes(axes, 3); + int n_axes = axes.size(); + compute_encoder.set_bytes(n_axes, 4); + MTL::Size dims = MTL::Size(1, 1, 1); + compute_encoder.dispatch_threads(dims, dims); + return offset; +} + } // namespace mlx::core diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 2cc4a6c17..af5bace9a 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -1,7 +1,6 @@ cuda_skip = { "TestLoad.test_load_f8_e4m3", "TestLayers.test_quantized_embedding", - "TestOps.test_dynamic_slicing", # Block masked matmul NYI "TestBlas.test_block_masked_matmul", # Gather matmul NYI