// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Each kernel below applies a binary op that produces two outputs per element
// pair (e.g. DivMod returns quotient and remainder). The suffix encodes the
// input layouts: s = scalar (broadcast), v = contiguous vector.

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    // Tail: handle the remaining elements one at a time.
    for (IdxT i = index * N_READS; i < size; ++i) {
      auto out = Op{}(a[0], b[0]);
      out_a[i] = out[0];
      out_b[i] = out[1];
    }
  } else {
    AlignedVector<Out, N_READS> out_a_vec;
    AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      auto out = Op{}(a[0], b[0]);
      out_a_vec[i] = out[0];
      out_b_vec[i] = out[1];
    }

    store_vector(out_a, index, out_a_vec);
    store_vector(out_b, index, out_b_vec);
  }
}

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      auto out = Op{}(a[0], b[i]);
      out_a[i] = out[0];
      out_b[i] = out[1];
    }
  } else {
    auto b_vec = load_vector<N_READS>(b, index);

    AlignedVector<Out, N_READS> out_a_vec;
    AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      auto out = Op{}(a[0], b_vec[i]);
      out_a_vec[i] = out[0];
      out_b_vec[i] = out[1];
    }

    store_vector(out_a, index, out_a_vec);
    store_vector(out_b, index, out_b_vec);
  }
}

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      auto out = Op{}(a[i], b[0]);
      out_a[i] = out[0];
      out_b[i] = out[1];
    }
  } else {
    auto a_vec = load_vector<N_READS>(a, index);

    AlignedVector<Out, N_READS> out_a_vec;
    AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      auto out = Op{}(a_vec[i], b[0]);
      out_a_vec[i] = out[0];
      out_b_vec[i] = out[1];
    }

    store_vector(out_a, index, out_a_vec);
    store_vector(out_b, index, out_b_vec);
  }
}

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      auto out = Op{}(a[i], b[i]);
      out_a[i] = out[0];
      out_b[i] = out[1];
    }
  } else {
    auto a_vec = load_vector<N_READS>(a, index);
    auto b_vec = load_vector<N_READS>(b, index);

    AlignedVector<Out, N_READS> out_a_vec;
    AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      auto out = Op{}(a_vec[i], b_vec[i]);
      out_a_vec[i] = out[0];
      out_b_vec[i] = out[1];
    }

    store_vector(out_a, index, out_a_vec);
    store_vector(out_b, index, out_b_vec);
  }
}

// General (strided) kernel with the number of dimensions known at compile
// time. The innermost dimension is vectorized with N_READS elements per
// thread in x; the remaining dimensions are flattened into the y dimension.
template <
    typename Op,
    typename In,
    typename Out,
    typename IdxT,
    int NDIM,
    int N_READS>
__global__ void binary_two_g_nd(
    const In* a,
    const In* b,
    Out* out_a,
    Out* out_b,
    IdxT size_rest,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }

  auto shape_x = shape[NDIM - 1];
  auto a_stride_x = a_strides[NDIM - 1];
  auto b_stride_x = b_strides[NDIM - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
      index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data());
  auto a_vec =
      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
  auto b_vec =
      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));

  AlignedVector<Out, N_READS> out_vec_a;
  AlignedVector<Out, N_READS> out_vec_b;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    auto out = Op{}(a_vec[i], b_vec[i]);
    out_vec_a[i] = out[0];
    out_vec_b[i] = out[1];
  }

  store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
  store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
}

// General (strided) kernel for an arbitrary number of dimensions passed at
// runtime.
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_two_g(
    const In* a,
    const In* b,
    Out* out_a,
    Out* out_b,
    IdxT size_rest,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides a_strides,
    const __grid_constant__ Strides b_strides,
    int ndim) {
  auto block = cg::this_thread_block();
  auto grid = cg::this_grid();
  IdxT index_rest =
      grid.block_index().y * block.dim_threads().y + block.thread_index().y;
  if (index_rest >= size_rest) {
    return;
  }

  auto shape_x = shape[ndim - 1];
  auto a_stride_x = a_strides[ndim - 1];
  auto b_stride_x = b_strides[ndim - 1];
  IdxT index_x =
      grid.block_index().x * block.dim_threads().x + block.thread_index().x;
  auto [a_idx, b_idx] = elem_to_loc(
      index_rest * shape_x,
      shape.data(),
      a_strides.data(),
      b_strides.data(),
      ndim);
  auto a_vec =
      load_vector<N_READS>(a + a_idx, index_x, shape_x, a_stride_x, In(0));
  auto b_vec =
      load_vector<N_READS>(b + b_idx, index_x, shape_x, b_stride_x, In(0));

  AlignedVector<Out, N_READS> out_vec_a;
  AlignedVector<Out, N_READS> out_vec_b;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    auto out = Op{}(a_vec[i], b_vec[i]);
    out_vec_a[i] = out[0];
    out_vec_b[i] = out[1];
  }

  store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x);
  store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x);
}

template <typename Op, typename In, typename Out>
constexpr bool supports_binary_two_op() {
  if (std::is_same_v<Op, DivMod>) {
    return std::is_same_v<In, Out> &&
        (std::is_integral_v<In> || is_floating_v<In>);
  }
  return false;
}

} // namespace cu

template <typename Op>
void binary_two_op_gpu_inplace(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const char* op,
    const Stream& s) {
  assert(inputs.size() > 1);
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  auto& out_a = outputs[0];
  auto& out_b = outputs[1];
  auto bopt = get_binary_op_type(a, b);

  auto& encoder = cu::get_command_encoder(s);
  set_binary_op_output_data(
      a, b, out_a, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
  set_binary_op_output_data(
      a, b, out_b, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });

  if (out_a.size() == 0) {
    return;
  }

  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out_a);
  encoder.set_output_array(out_b);
  dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
    dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
      if constexpr (cu::supports_binary_two_op<Op, CTYPE_IN, CTYPE_OUT>()) {
        using InType = cuda_type_t<CTYPE_IN>;
        using OutType = cuda_type_t<CTYPE_OUT>;

        auto bopt = get_binary_op_type(a, b);
        if (bopt == BinaryOpType::General) {
          dispatch_bool(
              a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
                  out_a.data_size() > INT32_MAX,
              [&](auto large) {
                using IdxT = std::conditional_t<large(), int64_t, int32_t>;
                Shape shape;
                std::vector<Strides> strides;
                std::tie(shape, strides) =
                    collapse_contiguous_dims(a, b, out_a);
                auto& a_strides = strides[0];
                auto& b_strides = strides[1];
                int ndim = shape.size();
                int work_per_thread = 1;
                auto dim0 = ndim > 0 ? shape.back() : 1;
                auto rest = out_a.size() / dim0;
                if (dim0 >= 4) {
                  work_per_thread = 4;
                }
                dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
                auto block_dims = get_block_dims(dim0, rest, 1);
                uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
                uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
                if (ndim <= 3) {
                  dispatch_1_2_3(ndim, [&](auto dims_constant) {
                    auto kernel = cu::binary_two_g_nd<
                        Op,
                        InType,
                        OutType,
                        IdxT,
                        dims_constant(),
                        1>;
                    if (work_per_thread == 4) {
                      kernel = cu::binary_two_g_nd<
                          Op,
                          InType,
                          OutType,
                          IdxT,
                          dims_constant(),
                          4>;
                    }
                    encoder.add_kernel_node(
                        kernel,
                        {num_blocks_x, num_blocks_y},
                        block_dims,
                        0,
                        gpu_ptr<InType>(a),
                        gpu_ptr<InType>(b),
                        gpu_ptr<OutType>(out_a),
                        gpu_ptr<OutType>(out_b),
                        rest,
                        const_param<dims_constant()>(shape),
                        const_param<dims_constant()>(a_strides),
                        const_param<dims_constant()>(b_strides));
                  });
                } else {
                  auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 1>;
                  if (work_per_thread == 4) {
                    kernel = cu::binary_two_g<Op, InType, OutType, IdxT, 4>;
                  }
                  encoder.add_kernel_node(
                      kernel,
                      {num_blocks_x, num_blocks_y},
                      block_dims,
                      0,
                      gpu_ptr<InType>(a),
                      gpu_ptr<InType>(b),
                      gpu_ptr<OutType>(out_a),
                      gpu_ptr<OutType>(out_b),
                      rest,
                      const_param(shape),
                      const_param(a_strides),
                      const_param(b_strides),
                      ndim);
                }
              });
        } else {
          dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
            constexpr int N_READS = 16 / sizeof(InType);
            auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
            if (bopt == BinaryOpType::ScalarVector) {
              kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
            } else if (bopt == BinaryOpType::VectorScalar) {
              kernel = cu::binary_two_vs<Op, InType, OutType, IdxT, N_READS>;
            } else if (bopt == BinaryOpType::VectorVector) {
              kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
            }
            auto [num_blocks, block_dims] = get_launch_args(
                out_a.data_size(),
                out_a.shape(),
                out_a.strides(),
                large(),
                N_READS);
            encoder.add_kernel_node(
                kernel,
                num_blocks,
                block_dims,
                0,
                gpu_ptr<InType>(a),
                gpu_ptr<InType>(b),
                gpu_ptr<OutType>(out_a),
                gpu_ptr<OutType>(out_b),
                out_a.data_size());
          });
        }
      } else {
        throw std::runtime_error(fmt::format(
            "Can not do binary op {} on inputs of {} with result of {}.",
            op,
            dtype_to_string(a.dtype()),
            dtype_to_string(out_a.dtype())));
      }
    });
  });
}

template <typename Op>
void binary_two_op_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const char* op,
    const Stream& s) {
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, outputs[0], bopt);
  set_binary_op_output_data(a, b, outputs[1], bopt);
  binary_two_op_gpu_inplace<Op>(inputs, outputs, op, s);
}

void DivMod::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("DivMod::eval_gpu");
  auto& s = outputs[0].primitive().stream();
  binary_two_op_gpu<cu::DivMod>(inputs, outputs, name(), s);
}

} // namespace mlx::core