diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu
index c9b0067c8a..752f4b54c5 100644
--- a/mlx/backend/cuda/binary.cu
+++ b/mlx/backend/cuda/binary.cu
@@ -211,12 +211,15 @@ void binary_op_gpu_inplace(
     int ndim = shape.size();
     if (ndim <= 3) {
       dispatch_1_2_3(ndim, [&](auto dims_constant) {
-        auto kernel = cu::
-            binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
         auto [num_blocks, block_dims] =
-            get_launch_args(kernel, out, large());
+            get_launch_args(out, large());
         encoder.add_kernel_node(
-            kernel,
+            cu::binary_g_nd<
+                Op,
+                InType,
+                OutType,
+                IdxT,
+                dims_constant()>,
             num_blocks,
             block_dims,
             a.data<InType>(),
@@ -228,11 +231,9 @@ void binary_op_gpu_inplace(
             const_param<dims_constant()>(b_strides));
       });
     } else {
-      auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
-      auto [num_blocks, block_dims] =
-          get_launch_args(kernel, out, large());
+      auto [num_blocks, block_dims] = get_launch_args(out, large());
       encoder.add_kernel_node(
-          kernel,
+          cu::binary_g<Op, InType, OutType, IdxT>,
           num_blocks,
           block_dims,
           a.data<InType>(),
@@ -258,12 +259,7 @@ void binary_op_gpu_inplace(
         kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
       }
       auto [num_blocks, block_dims] = get_launch_args(
-          kernel,
-          out.data_size(),
-          out.shape(),
-          out.strides(),
-          large(),
-          N_READS);
+          out.data_size(), out.shape(), out.strides(), large(), N_READS);
       encoder.add_kernel_node(
           kernel,
           num_blocks,
diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu
index 598924c098..a56f004687 100644
--- a/mlx/backend/cuda/binary_two.cu
+++ b/mlx/backend/cuda/binary_two.cu
@@ -227,16 +227,15 @@ void binary_two_op_gpu_inplace(
     int ndim = shape.size();
     if (ndim <= 3) {
       dispatch_1_2_3(ndim, [&](auto dims_constant) {
-        auto kernel = cu::binary_two_g_nd<
-            Op,
-            InType,
-            OutType,
-            IdxT,
-            dims_constant()>;
         auto [num_blocks, block_dims] =
-            get_launch_args(kernel, out_a, large());
+            get_launch_args(out_a, large());
         encoder.add_kernel_node(
-            kernel,
+            cu::binary_two_g_nd<
+                Op,
+                InType,
+                OutType,
+                IdxT,
+                dims_constant()>,
             num_blocks,
             block_dims,
             a.data<InType>(),
@@ -249,11 +248,10 @@ void binary_two_op_gpu_inplace(
             const_param<dims_constant()>(b_strides));
       });
     } else {
-      auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT>;
       auto [num_blocks, block_dims] =
-          get_launch_args(kernel, out_a, large());
+          get_launch_args(out_a, large());
       encoder.add_kernel_node(
-          kernel,
+          cu::binary_two_g<Op, InType, OutType, IdxT>,
           num_blocks,
           block_dims,
           a.data<InType>(),
@@ -280,7 +278,6 @@ void binary_two_op_gpu_inplace(
         kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
       }
       auto [num_blocks, block_dims] = get_launch_args(
-          kernel,
           out_a.data_size(),
           out_a.shape(),
           out_a.strides(),
diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp
index 1aff47b89a..6eda2533f4 100644
--- a/mlx/backend/cuda/compiled.cpp
+++ b/mlx/backend/cuda/compiled.cpp
@@ -294,7 +294,7 @@ void Compiled::eval_gpu(
 
   auto kernel = mod.get_kernel(kernel_name);
   auto [num_blocks, block_dims] =
-      get_launch_args(kernel, outputs[0], large, work_per_thread);
+      get_launch_args(outputs[0], large, work_per_thread);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu
index 8ac0533f37..3ec7478be6 100644
--- a/mlx/backend/cuda/copy/copy_contiguous.cu
+++ b/mlx/backend/cuda/copy/copy_contiguous.cu
@@ -71,12 +71,7 @@ void copy_contiguous(
         kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;
       }
       auto [num_blocks, block_dims] = get_launch_args(
-          kernel,
-          out.data_size(),
-          out.shape(),
-          out.strides(),
-          large(),
-          N_READS);
+          out.data_size(), out.shape(), out.strides(), large(), N_READS);
       encoder.add_kernel_node(
           kernel,
           num_blocks,
diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu
index e92160b952..b65a24e547 100644
--- a/mlx/backend/cuda/copy/copy_general.cu
+++ b/mlx/backend/cuda/copy/copy_general.cu
@@ -71,12 +71,10 @@ void copy_general(
       data_size *= s;
     if (ndim <= 3) {
       dispatch_1_2_3(ndim, [&](auto ndim_constant) {
-        auto kernel =
-            cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
-        auto [num_blocks, block_dims] = get_launch_args(
-            kernel, data_size, shape, out.strides(), large());
+        auto [num_blocks, block_dims] =
+            get_launch_args(data_size, shape, out.strides(), large());
         encoder.add_kernel_node(
-            kernel,
+            cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
             num_blocks,
             block_dims,
             in_ptr,
@@ -87,11 +85,10 @@ void copy_general(
             const_param<ndim_constant()>(strides_out));
       });
     } else { // ndim >= 4
-      auto kernel = cu::copy_gg<InType, OutType, IdxT>;
-      auto [num_blocks, block_dims] = get_launch_args(
-          kernel, data_size, shape, out.strides(), large());
+      auto [num_blocks, block_dims] =
+          get_launch_args(data_size, shape, out.strides(), large());
       encoder.add_kernel_node(
-          kernel,
+          cu::copy_gg<InType, OutType, IdxT>,
           num_blocks,
           block_dims,
           in_ptr,
diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu
index 419dd73fb1..bafc82057f 100644
--- a/mlx/backend/cuda/copy/copy_general_dynamic.cu
+++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu
@@ -74,12 +74,13 @@ void copy_general_dynamic(
     int ndim = shape.size();
     if (ndim <= 3) {
       dispatch_1_2_3(ndim, [&](auto dims_constant) {
-        auto kernel = cu::
-            copy_gg_dynamic_nd<InType, OutType, IdxT, dims_constant()>;
-        auto [num_blocks, block_dims] =
-            get_launch_args(kernel, out, large());
+        auto [num_blocks, block_dims] = get_launch_args(out, large());
         encoder.add_kernel_node(
-            kernel,
+            cu::copy_gg_dynamic_nd<
+                InType,
+                OutType,
+                IdxT,
+                dims_constant()>,
             num_blocks,
             block_dims,
             in_ptr,
@@ -92,11 +93,9 @@ void copy_general_dynamic(
             dynamic_offset_out.data<int64_t>());
       });
     } else { // ndim >= 4
-      auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
-      auto [num_blocks, block_dims] =
-          get_launch_args(kernel, out, large());
+      auto [num_blocks, block_dims] = get_launch_args(out, large());
       encoder.add_kernel_node(
-          kernel,
+          cu::copy_gg_dynamic<InType, OutType, IdxT>,
           num_blocks,
           block_dims,
           in_ptr,
diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu
index c66f3a7778..052cf56c33 100644
--- a/mlx/backend/cuda/copy/copy_general_input.cu
+++ b/mlx/backend/cuda/copy/copy_general_input.cu
@@ -63,12 +63,9 @@ void copy_general_input(
     int ndim = shape.size();
     if (ndim <= 3) {
       dispatch_1_2_3(ndim, [&](auto dims_constant) {
-        auto kernel =
-            cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
-        auto [num_blocks, block_dims] =
-            get_launch_args(kernel, out, large());
+        auto [num_blocks, block_dims] = get_launch_args(out, large());
         encoder.add_kernel_node(
-            kernel,
+            cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
             num_blocks,
             block_dims,
             in_ptr,
@@ -78,11 +75,9 @@ void copy_general_input(
             const_param<dims_constant()>(strides_in));
       });
     } else { // ndim >= 4
-      auto kernel = cu::copy_g<InType, OutType, IdxT>;
-      auto [num_blocks, block_dims] =
-          get_launch_args(kernel, out, large());
+      auto [num_blocks, block_dims] = get_launch_args(out, large());
       encoder.add_kernel_node(
-          kernel,
+          cu::copy_g<InType, OutType, IdxT>,
           num_blocks,
           block_dims,
           in_ptr,
diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp
index 4b03a604ec..69a85f6acb 100644
--- a/mlx/backend/cuda/indexing.cpp
+++ b/mlx/backend/cuda/indexing.cpp
@@ -128,7 +128,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   encoder.set_output_array(out);
   auto kernel = mod.get_kernel(kernel_name);
-  auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+  auto [num_blocks, block_dims] = get_launch_args(out, large);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
@@ -229,7 +229,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
   encoder.set_output_array(out);
   auto kernel = mod.get_kernel(kernel_name);
-  auto [num_blocks, block_dims] = get_launch_args(kernel, upd, large);
+  auto [num_blocks, block_dims] = get_launch_args(upd, large);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
@@ -317,7 +317,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
   encoder.set_output_array(out);
   auto kernel = mod.get_kernel(kernel_name);
-  auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
+  auto [num_blocks, block_dims] = get_launch_args(idx, large);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
@@ -421,7 +421,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
   encoder.set_output_array(out);
   auto kernel = mod.get_kernel(kernel_name);
-  auto [num_blocks, block_dims] = get_launch_args(kernel, idx, large);
+  auto [num_blocks, block_dims] = get_launch_args(idx, large);
   encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
 }
 
diff --git a/mlx/backend/cuda/kernel_utils.cu b/mlx/backend/cuda/kernel_utils.cu
index 7b87aa5b05..9ac9a82dae 100644
--- a/mlx/backend/cuda/kernel_utils.cu
+++ b/mlx/backend/cuda/kernel_utils.cu
@@ -30,4 +30,25 @@ std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2) {
   return std::make_pair(dim3(gx, gy, gz), dim3(bx, by, bz));
 }
 
+std::tuple<dim3, uint> get_launch_args(
+    size_t size,
+    const Shape& shape,
+    const Strides& strides,
+    bool large,
+    int work_per_thread) {
+  size_t nthreads = cuda::ceil_div(size, work_per_thread);
+  uint block_dim = 1024;
+  if (block_dim > nthreads) {
+    block_dim = nthreads;
+  }
+  dim3 num_blocks;
+  if (large) {
+    num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
+    num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
+  } else {
+    num_blocks.x = cuda::ceil_div(nthreads, block_dim);
+  }
+  return std::make_tuple(num_blocks, block_dim);
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh
index bf10de6497..fbbca0a060 100644
--- a/mlx/backend/cuda/kernel_utils.cuh
+++ b/mlx/backend/cuda/kernel_utils.cuh
@@ -122,37 +122,17 @@ std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
 
 // Get the num_blocks and block_dims that maximize occupancy for |kernel|,
 // assuming each thread handles |work_per_thread| elements of |arr|.
-template <typename T>
-inline std::tuple<dim3, uint> get_launch_args(
-    T kernel,
+std::tuple<dim3, uint> get_launch_args(
     size_t size,
     const Shape& shape,
     const Strides& strides,
     bool large,
-    int work_per_thread = 1) {
-  size_t nthreads = cuda::ceil_div(size, work_per_thread);
-  uint block_dim = 1024;
-  if (block_dim > nthreads) {
-    block_dim = nthreads;
-  }
-  dim3 num_blocks;
-  if (large) {
-    num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
-    num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
-  } else {
-    num_blocks.x = cuda::ceil_div(nthreads, block_dim);
-  }
-  return std::make_tuple(num_blocks, block_dim);
-}
+    int work_per_thread = 1);
 
-template <typename T>
-inline std::tuple<dim3, uint> get_launch_args(
-    T kernel,
-    const array& arr,
-    bool large,
-    int work_per_thread = 1) {
+inline std::tuple<dim3, uint>
+get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
   return get_launch_args(
-      kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
+      arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/quantized.cu b/mlx/backend/cuda/quantized.cu
index 204dbd5472..5702fa5a9f 100644
--- a/mlx/backend/cuda/quantized.cu
+++ b/mlx/backend/cuda/quantized.cu
@@ -350,12 +350,10 @@ void fast::AffineQuantize::eval_gpu(
     dispatch_bits(bits_, [&](auto bits) {
      using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       if (dequantize_) {
-        auto kernel =
-            cu::affine_dequantize<DataType, group_size(), bits()>;
         auto [num_blocks, block_dims] =
-            get_launch_args(kernel, size, grid_shape, w.strides(), large);
+            get_launch_args(size, grid_shape, w.strides(), large);
         enc.add_kernel_node(
-            kernel,
+            cu::affine_dequantize<DataType, group_size(), bits()>,
             num_blocks,
             block_dims,
             w.data<uint8_t>(),
@@ -364,12 +362,10 @@ void fast::AffineQuantize::eval_gpu(
             out.data<DataType>(),
             out.size());
       } else {
-        auto kernel =
-            cu::affine_quantize<DataType, group_size(), bits()>;
         auto [num_blocks, block_dims] =
-            get_launch_args(kernel, size, grid_shape, w.strides(), large);
+            get_launch_args(size, grid_shape, w.strides(), large);
         enc.add_kernel_node(
-            kernel,
+            cu::affine_quantize<DataType, group_size(), bits()>,
             num_blocks,
             block_dims,
             w.data<DataType>(),
diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu
index 6f9c6c7e4d..58d3fa119a 100644
--- a/mlx/backend/cuda/ternary.cu
+++ b/mlx/backend/cuda/ternary.cu
@@ -125,12 +125,9 @@ void ternary_op_gpu_inplace(
     int ndim = shape.size();
     if (ndim <= 3) {
       dispatch_1_2_3(ndim, [&](auto dims_constant) {
-        auto kernel =
-            cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>;
-        auto [num_blocks, block_dims] =
-            get_launch_args(kernel, out, large());
+        auto [num_blocks, block_dims] = get_launch_args(out, large());
         encoder.add_kernel_node(
-            kernel,
+            cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>,
             num_blocks,
             block_dims,
             a.data<bool>(),
@@ -144,11 +141,9 @@ void ternary_op_gpu_inplace(
             const_param<dims_constant()>(c_strides));
       });
     } else {
-      auto kernel = cu::ternary_g<Op, DType, IdxT>;
-      auto [num_blocks, block_dims] =
-          get_launch_args(kernel, out, large());
+      auto [num_blocks, block_dims] = get_launch_args(out, large());
       encoder.add_kernel_node(
-          kernel,
+          cu::ternary_g<Op, DType, IdxT>,
           num_blocks,
           block_dims,
           a.data<bool>(),
@@ -167,16 +162,10 @@ void ternary_op_gpu_inplace(
     dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
       using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
       constexpr int N_READS = 16 / sizeof(DType);
-      auto kernel = cu::ternary_v<Op, DType, IdxT, N_READS>;
       auto [num_blocks, block_dims] = get_launch_args(
-          kernel,
-          out.data_size(),
-          out.shape(),
-          out.strides(),
-          large(),
-          N_READS);
+          out.data_size(), out.shape(), out.strides(), large(), N_READS);
       encoder.add_kernel_node(
-          kernel,
+          cu::ternary_v<Op, DType, IdxT, N_READS>,
           num_blocks,
           block_dims,
           a.data<bool>(),
diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu
index 68d04b9eaf..9c8db9a89e 100644
--- a/mlx/backend/cuda/unary.cu
+++ b/mlx/backend/cuda/unary.cu
@@ -129,16 +129,10 @@ void unary_op_gpu_inplace(
       using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
       // TODO: Choose optimized value based on type size.
       constexpr int N_READS = 4;
-      auto kernel = cu::unary_v<Op, InType, OutType, IdxT, N_READS>;
       auto [num_blocks, block_dims] = get_launch_args(
-          kernel,
-          out.data_size(),
-          out.shape(),
-          out.strides(),
-          large,
-          N_READS);
+          out.data_size(), out.shape(), out.strides(), large, N_READS);
       encoder.add_kernel_node(
-          kernel,
+          cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
           num_blocks,
           block_dims,
           in.data<InType>(),
@@ -147,10 +141,9 @@ void unary_op_gpu_inplace(
     } else {
       using IdxT = std::conditional_t<large(), int64_t, int32_t>;
       auto [shape, strides] = collapse_contiguous_dims(in);
-      auto kernel = cu::unary_g<Op, InType, OutType, IdxT>;
-      auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+      auto [num_blocks, block_dims] = get_launch_args(out, large);
       encoder.add_kernel_node(
-          kernel,
+          cu::unary_g<Op, InType, OutType, IdxT>,
           num_blocks,
           block_dims,
           in.data<InType>(),
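Note on the refactor: `get_launch_args` never used the kernel argument it received, so dropping it lets call sites pass the kernel template instantiation straight to `add_kernel_node`. The launch geometry depends only on the element count, the output shape/strides, the `large`-index flag, and `work_per_thread`. Below is a standalone sketch of the non-`large` branch of the grid math as added in `kernel_utils.cu` above, with a local `ceil_div` standing in for `cuda::ceil_div` and purely illustrative numbers; it is not part of the MLX build.

```cpp
#include <cstdio>

// Stand-in for cuda::ceil_div used in kernel_utils.cu (assumption, not MLX code).
static size_t ceil_div(size_t a, size_t b) {
  return (a + b - 1) / b;
}

int main() {
  // Illustrative values: 1M elements, 4 elements per thread (N_READS-style).
  size_t size = 1000000;
  int work_per_thread = 4;

  size_t nthreads = ceil_div(size, work_per_thread); // 250000 threads needed
  unsigned block_dim = 1024;                         // default block size
  if (block_dim > nthreads) {
    block_dim = nthreads; // tiny arrays run as a single partial block
  }
  size_t num_blocks_x = ceil_div(nthreads, block_dim); // 245 blocks

  std::printf("grid = %zu x 1 x 1, block = %u\n", num_blocks_x, block_dim);
  return 0;
}
```

The `large` branch of the real function instead starts from `get_2d_grid_dims(shape, strides, work_per_thread)` and divides only the x-dimension by `block_dim`, spreading very large arrays over a 2-D grid rather than overflowing a 1-D one.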