diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu
index 168b390e3..629758197 100644
--- a/mlx/backend/cuda/binary_two.cu
+++ b/mlx/backend/cuda/binary_two.cu
@@ -19,7 +19,7 @@ namespace cg = cooperative_groups;
 
 template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void
-binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
+binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
   int remaining = size - index * N_READS;
   if (remaining <= 0) {
@@ -50,7 +50,7 @@ binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
 
 template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void
-binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
+binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
   int remaining = size - index * N_READS;
   if (remaining <= 0) {
@@ -83,7 +83,7 @@ binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
 
 template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void
-binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
+binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
   int remaining = size - index * N_READS;
   if (remaining <= 0) {
@@ -116,7 +116,7 @@ binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
 
 template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
 __global__ void
-binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
+binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
   IdxT index = cg::this_grid().thread_rank();
   int remaining = size - index * N_READS;
   if (remaining <= 0) {
@@ -149,7 +149,7 @@ binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
 }
 
 template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
-__global__ void binary_g_nd(
+__global__ void binary_two_g_nd(
     const In* a,
     const In* b,
     Out* out_a,
@@ -169,7 +169,7 @@ __global__ void binary_g_nd(
 }
 
 template <typename Op, typename In, typename Out, typename IdxT>
-__global__ void binary_g(
+__global__ void binary_two_g(
     const In* a,
     const In* b,
     Out* out_a,
@@ -190,7 +190,7 @@ __global__ void binary_g(
 }
 
 template <typename Op, typename In, typename Out>
-constexpr bool supports_binary_op() {
+constexpr bool supports_binary_two_op() {
   if (std::is_same_v<Op, DivMod>) {
     return std::is_same_v<In, Out> &&
         (std::is_integral_v<In> || is_floating_v<In>);
@@ -201,7 +201,7 @@ constexpr bool supports_binary_op() {
 } // namespace cu
 
 template <typename Op>
-void binary_op_gpu_inplace(
+void binary_two_op_gpu_inplace(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
     std::string_view op,
@@ -228,7 +228,7 @@ void binary_op_gpu_inplace(
     dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
       using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
       using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
-      if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
+      if constexpr (cu::supports_binary_two_op<Op, CTYPE_IN, CTYPE_OUT>()) {
        using InType = cuda_type_t<CTYPE_IN>;
        using OutType = cuda_type_t<CTYPE_OUT>;
 
@@ -248,8 +248,12 @@ void binary_op_gpu_inplace(
           int ndim = shape.size();
           if (ndim <= 3) {
             dispatch_1_2_3(ndim, [&](auto dims_constant) {
-              auto kernel = cu::
-                  binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
+              auto kernel = cu::binary_two_g_nd<
+                  Op,
+                  InType,
+                  OutType,
+                  IdxT,
+                  dims_constant()>;
               auto [num_blocks, block_dims] =
                   get_launch_args(kernel, out_a, large());
               encoder.add_kernel_node(
@@ -266,7 +270,7 @@ void binary_op_gpu_inplace(
                   const_param<dims_constant()>(b_strides));
             });
           } else {
-            auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
+            auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT>;
             auto [num_blocks, block_dims] =
                 get_launch_args(kernel, out_a, large());
             encoder.add_kernel_node(
@@ -289,13 +293,13 @@ void binary_op_gpu_inplace(
         using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
         // TODO: Choose optimized value based on type size.
         constexpr int N_READS = 4;
-        auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
+        auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
         if (bopt == BinaryOpType::ScalarVector) {
-          kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
+          kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
         } else if (bopt == BinaryOpType::VectorScalar) {
-          kernel = cu::binary_vs<Op, InType, OutType, IdxT, N_READS>;
+          kernel = cu::binary_two_vs<Op, InType, OutType, IdxT, N_READS>;
         } else if (bopt == BinaryOpType::VectorVector) {
-          kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
+          kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
         }
         auto [num_blocks, block_dims] = get_launch_args(
             kernel,
@@ -327,7 +331,7 @@ void binary_op_gpu_inplace(
 }
 
 template <typename Op>
-void binary_op_gpu(
+void binary_two_op_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    std::string_view op,
@@ -337,7 +341,7 @@ void binary_op_gpu(
   auto bopt = get_binary_op_type(a, b);
   set_binary_op_output_data(a, b, outputs[0], bopt);
   set_binary_op_output_data(a, b, outputs[1], bopt);
-  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
+  binary_two_op_gpu_inplace<Op>(inputs, outputs, op, s);
 }
 
 void DivMod::eval_gpu(
@@ -345,7 +349,7 @@ void DivMod::eval_gpu(
     std::vector<array>& outputs) {
   nvtx3::scoped_range r("DivMod::eval_gpu");
   auto& s = outputs[0].primitive().stream();
-  binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
+  binary_two_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
 }
 
 } // namespace mlx::core
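For reference, every renamed `binary_two_*` kernel follows the same two-output elementwise pattern: a single `Op` consumes `a[i]` and `b[i]` and produces a pair of results written to `out_a` and `out_b` (for `DivMod`, the quotient and the remainder). The standalone CUDA sketch below only illustrates that pattern; the `DivModOp` functor, the `binary_two_vv_sketch` kernel name, and the simplified indexing (one element per thread, no `N_READS` vectorization, no scalar/vector contiguity variants, no `cooperative_groups`) are hypothetical and not MLX's implementation.

```cuda
#include <cuda_runtime.h>
#include <cstdio>

// Illustrative two-output binary op: quotient and remainder per element.
// Hypothetical functor for this sketch, not MLX's cu::DivMod.
struct DivModOp {
  __device__ void operator()(int a, int b, int& q, int& r) const {
    q = a / b;
    r = a % b;
  }
};

// Simplified "vector-vector" variant of the two-output pattern:
// one thread per element, two output arrays filled per element.
template <typename Op, typename In, typename Out>
__global__ void binary_two_vv_sketch(
    const In* a, const In* b, Out* out_a, Out* out_b, size_t size, Op op) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < size) {
    op(a[i], b[i], out_a[i], out_b[i]);
  }
}

int main() {
  const size_t n = 8;
  int ha[n], hb[n], hq[n], hr[n];
  for (size_t i = 0; i < n; ++i) {
    ha[i] = int(i) + 7;
    hb[i] = 3;
  }

  int *da, *db, *dq, *dr;
  cudaMalloc(&da, n * sizeof(int));
  cudaMalloc(&db, n * sizeof(int));
  cudaMalloc(&dq, n * sizeof(int));
  cudaMalloc(&dr, n * sizeof(int));
  cudaMemcpy(da, ha, n * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(db, hb, n * sizeof(int), cudaMemcpyHostToDevice);

  // One block is enough for 8 elements; the bounds check handles the rest.
  binary_two_vv_sketch<<<1, 32>>>(da, db, dq, dr, n, DivModOp{});

  cudaMemcpy(hq, dq, n * sizeof(int), cudaMemcpyDeviceToHost);
  cudaMemcpy(hr, dr, n * sizeof(int), cudaMemcpyDeviceToHost);
  for (size_t i = 0; i < n; ++i) {
    printf("%d / %d = %d rem %d\n", ha[i], hb[i], hq[i], hr[i]);
  }

  cudaFree(da);
  cudaFree(db);
  cudaFree(dq);
  cudaFree(dr);
  return 0;
}
```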