// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = Op{}(a[0], b[0]);
    }
  } else {
    AlignedVector<Out, N_READS> out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec[i] = Op{}(a[0], b[0]);
    }

    store_vector(out, index, out_vec);
  }
}

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = Op{}(a[0], b[i]);
    }
  } else {
    auto b_vec = load_vector<N_READS>(b, index);

    AlignedVector<Out, N_READS> out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec[i] = Op{}(a[0], b_vec[i]);
    }

    store_vector(out, index, out_vec);
  }
}

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = Op{}(a[i], b[0]);
    }
  } else {
    auto a_vec = load_vector<N_READS>(a, index);

    AlignedVector<Out, N_READS> out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec[i] = Op{}(a_vec[i], b[0]);
    }

    store_vector(out, index, out_vec);
  }
}

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = Op{}(a[i], b[i]);
    }
  } else {
    auto a_vec = load_vector<N_READS>(a, index);
    auto b_vec = load_vector<N_READS>(b, index);

    AlignedVector<Out, N_READS> out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec[i] = Op{}(a_vec[i], b_vec[i]);
    }

    store_vector(out, index, out_vec);
  }
}
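// The four kernels above handle the contiguous cases: scalar/scalar (ss),
// scalar/vector (sv), vector/scalar (vs), and vector/vector (vv). Each
// thread processes N_READS contiguous elements with one vectorized
// load/store; the host picks N_READS = 16 / sizeof(InType) (e.g. 4 for
// float), so a full thread moves 16 bytes per operand. A thread whose
// window would run past `size` falls back to the scalar tail loop instead.
// The kernels below handle the general (strided) case.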
template <
    typename Op,
    typename In,
    typename Out,
    typename IdxT,
    int NDIM,
    int N_READS>
__global__ void binary_g_nd(
    const In* a,
    const In* b,
    Out* out,
    IdxT size_rest,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }

  auto shape_x = shape[NDIM - 1];
  auto a_stride_x = a_strides[NDIM - 1];
  auto b_stride_x = b_strides[NDIM - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
      index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
  auto a_vec =
      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
  auto b_vec =
      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));

  AlignedVector<Out, N_READS> out_vec;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    out_vec[i] = Op{}(a_vec[i], b_vec[i]);
  }

  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_g(
    const In* a,
    const In* b,
    Out* out,
    IdxT size_rest,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides a_strides,
    const __grid_constant__ Strides b_strides,
    int ndim) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }

  auto shape_x = shape[ndim - 1];
  auto a_stride_x = a_strides[ndim - 1];
  auto b_stride_x = b_strides[ndim - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto [a_idx, b_idx] = elem_to_loc(
      index_rest * shape_x,
      shape.data(),
      a_strides.data(),
      b_strides.data(),
      ndim);
  auto a_vec =
      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
  auto b_vec =
      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));

  AlignedVector<Out, N_READS> out_vec;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    out_vec[i] = Op{}(a_vec[i], b_vec[i]);
  }

  store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
}

template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
  if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
      std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
      std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
      std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
    return std::is_same_v<In, Out>;
  }
  if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
      std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
      std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
    return std::is_same_v<Out, bool>;
  }
  if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
    return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
  }
  if (std::is_same_v<Op, NaNEqual>) {
    return std::is_same_v<Out, bool> && is_inexact_v<In>;
  }
  if (std::is_same_v<Op, LogAddExp>) {
    return std::is_same_v<In, Out> && is_inexact_v<In>;
  }
  if (std::is_same_v<Op, ArcTan2>) {
    return std::is_same_v<In, Out> && is_floating_v<In>;
  }
  if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
      std::is_same_v<Op, BitwiseXor>) {
    return std::is_same_v<In, Out> && std::is_integral_v<In>;
  }
  if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
    return std::is_same_v<In, Out> && std::is_integral_v<In> &&
        !std::is_same_v<In, bool>;
  }
  return false;
}

} // namespace cu
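// supports_binary_op() prunes the dtype dispatch at compile time so only
// valid (Op, In, Out) combinations are instantiated: arithmetic ops require
// In == Out, comparisons produce bool, bitwise ops require integral inputs,
// and so on. The host entry points below then pick a kernel from the
// BinaryOpType: the General path handles arbitrary strides (binary_g_nd for
// ndim <= 3, binary_g otherwise), while contiguous inputs use the
// ss/sv/vs/vv kernels. Index arithmetic is widened to 64 bits only when an
// array spans more than INT32_MAX (general) or UINT32_MAX (contiguous)
// elements.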
template <typename Op>
void binary_op_gpu_inplace(
    const std::vector<array>& inputs,
    array& out,
    const char* op,
    const Stream& s) {
  assert(inputs.size() > 1);
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  if (out.size() == 0) {
    return;
  }

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
  dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
      if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
        using InType = cuda_type_t<CTYPE_IN>;
        using OutType = cuda_type_t<CTYPE_OUT>;

        auto bopt = get_binary_op_type(a, b);
        if (bopt == BinaryOpType::General) {
          dispatch_bool(
              a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
                  out.data_size() > INT32_MAX,
              [&](auto large) {
                using IdxT = std::conditional_t<large(), int64_t, int32_t>;
                Shape shape;
                std::vector<Strides> strides;
                std::tie(shape, strides) = collapse_contiguous_dims(a, b, out);
                auto& a_strides = strides[0];
                auto& b_strides = strides[1];
                int ndim = shape.size();
                int work_per_thread = 1;
                auto dim0 = ndim > 0 ? shape.back() : 1;
                auto rest = out.size() / dim0;
                if (dim0 >= 4) {
                  work_per_thread = 4;
                }
                dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
                auto block_dims = get_block_dims(dim0, rest, 1);
                uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
                uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
                if (ndim <= 3) {
                  dispatch_1_2_3(ndim, [&](auto dims_constant) {
                    auto kernel = cu::binary_g_nd<
                        Op,
                        InType,
                        OutType,
                        IdxT,
                        dims_constant(),
                        1>;
                    if (work_per_thread == 4) {
                      kernel = cu::binary_g_nd<
                          Op,
                          InType,
                          OutType,
                          IdxT,
                          dims_constant(),
                          4>;
                    }
                    encoder.add_kernel_node(
                        kernel,
                        {num_blocks_x, num_blocks_y},
                        block_dims,
                        0,
                        gpu_ptr<InType>(a),
                        gpu_ptr<InType>(b),
                        gpu_ptr<OutType>(out),
                        rest,
                        const_param<dims_constant()>(shape),
                        const_param<dims_constant()>(a_strides),
                        const_param<dims_constant()>(b_strides));
                  });
                } else {
                  auto kernel = cu::binary_g<Op, InType, OutType, IdxT, 1>;
                  if (work_per_thread == 4) {
                    kernel = cu::binary_g<Op, InType, OutType, IdxT, 4>;
                  }
                  encoder.add_kernel_node(
                      kernel,
                      {num_blocks_x, num_blocks_y},
                      block_dims,
                      0,
                      gpu_ptr<InType>(a),
                      gpu_ptr<InType>(b),
                      gpu_ptr<OutType>(out),
                      rest,
                      const_param(shape),
                      const_param(a_strides),
                      const_param(b_strides),
                      ndim);
                }
              });
        } else {
          dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
            constexpr int N_READS = 16 / sizeof(InType);
            auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
            if (bopt == BinaryOpType::ScalarVector) {
              kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
            } else if (bopt == BinaryOpType::VectorScalar) {
              kernel = cu::binary_vs<Op, InType, OutType, IdxT, N_READS>;
            } else if (bopt == BinaryOpType::VectorVector) {
              kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
            }
            auto [num_blocks, block_dims] = get_launch_args(
                out.data_size(), out.shape(), out.strides(), large(), N_READS);
            encoder.add_kernel_node(
                kernel,
                num_blocks,
                block_dims,
                0,
                gpu_ptr<InType>(a),
                gpu_ptr<InType>(b),
                gpu_ptr<OutType>(out),
                out.data_size());
          });
        }
      } else {
        throw std::runtime_error(fmt::format(
            "Can not do binary op {} on inputs of {} with result of {}.",
            op,
            dtype_to_string(a.dtype()),
            dtype_to_string(out.dtype())));
      }
    });
  });
}

template <typename Op>
void binary_op_gpu(
    const std::vector<array>& inputs,
    array& out,
    const char* op,
    const Stream& s) {
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  auto& encoder = cu::get_command_encoder(s);
  set_binary_op_output_data(a, b, out, bopt, [&](auto n) {
    return cu::malloc_async(n, encoder.stream());
  });
  binary_op_gpu_inplace<Op>(inputs, out, op, s);
}

#define BINARY_GPU(func)                                                \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) {  \
    nvtx3::scoped_range r(#func "::eval_gpu");                          \
    auto& s = out.primitive().stream();                                 \
    binary_op_gpu<cu::func>(inputs, out, name(), s);                    \
  }

} // namespace mlx::core
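// Illustrative expansion (the op instantiations themselves live alongside
// this macro in the backend): BINARY_GPU(Add) generates
//
//   void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
//     nvtx3::scoped_range r("Add::eval_gpu");
//     auto& s = out.primitive().stream();
//     binary_op_gpu<cu::Add>(inputs, out, name(), s);
//   }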