// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Contiguous kernels. Op returns two values; they are written to the two
// outputs at the same linear index.
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto out = Op{}(a[0], b[0]);
    out_a[0] = out[0];
    out_b[0] = out[1];
  }
}

template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto out = Op{}(a[0], b[index]);
    out_a[index] = out[0];
    out_b[index] = out[1];
  }
}

template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto out = Op{}(a[index], b[0]);
    out_a[index] = out[0];
    out_b[index] = out[1];
  }
}

template <typename Op, typename In, typename Out, typename IdxT>
__global__ void
binary_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto out = Op{}(a[index], b[index]);
    out_a[index] = out[0];
    out_b[index] = out[1];
  }
}

// General (strided) kernels: map the linear output index to the input
// locations from the broadcasted shape and strides.
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
    const In* a,
    const In* b,
    Out* out_a,
    Out* out_b,
    IdxT size,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
        index, shape.data(), a_strides.data(), b_strides.data());
    auto out = Op{}(a[a_idx], b[b_idx]);
    out_a[index] = out[0];
    out_b[index] = out[1];
  }
}

template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
    const In* a,
    const In* b,
    Out* out_a,
    Out* out_b,
    IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides a_strides,
    const __grid_constant__ Strides b_strides,
    int ndim) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [a_idx, b_idx] = elem_to_loc_4d(
        index, shape.data(), a_strides.data(), b_strides.data(), ndim);
    auto out = Op{}(a[a_idx], b[b_idx]);
    out_a[index] = out[0];
    out_b[index] = out[1];
  }
}

template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
  if (std::is_same_v<Op, DivMod>) {
    return std::is_same_v<In, Out> &&
        (std::is_integral_v<Out> || is_floating_v<Out>);
  }
  return false;
}

} // namespace cu

template <typename Op>
void binary_op_gpu_inplace(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    std::string_view op,
    const Stream& s) {
  assert(inputs.size() > 1);
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  auto& out_a = outputs[0];
  auto& out_b = outputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out_a, bopt);
  set_binary_op_output_data(a, b, out_b, bopt);
  if (out_a.size() == 0) {
    return;
  }

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out_a);
  encoder.set_output_array(out_b);
  encoder.launch_kernel([&](cudaStream_t stream) {
    // Dispatch on the input/output dtypes and the contiguity of the inputs.
    MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
      MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, {
        if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
          using InType = cuda_type_t<CTYPE_IN>;
          using OutType = cuda_type_t<CTYPE_OUT>;
          auto bopt = get_binary_op_type(a, b);
          if (bopt == BinaryOpType::General) {
            auto [shape, strides] = collapse_contiguous_dims(a, b, out_a);
            auto& a_strides = strides[0];
            auto& b_strides = strides[1];
            bool large = a.data_size() > INT32_MAX ||
                b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX;
            MLX_SWITCH_BOOL(large, LARGE, {
              using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
              int ndim = shape.size();
              if (ndim <= 3) {
                MLX_SWITCH_1_2_3(ndim, NDIM, {
                  auto kernel =
                      &cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
                  auto [num_blocks, block_dims] =
                      get_launch_args(kernel, out_a, large);
                  kernel<<<num_blocks, block_dims, 0, stream>>>(
                      a.data<InType>(),
                      b.data<InType>(),
                      out_a.data<OutType>(),
                      out_b.data<OutType>(),
                      out_a.size(),
                      const_param<NDIM>(shape),
                      const_param<NDIM>(a_strides),
                      const_param<NDIM>(b_strides));
                });
              } else {
                auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
                auto [num_blocks, block_dims] =
                    get_launch_args(kernel, out_a, large);
                kernel<<<num_blocks, block_dims, 0, stream>>>(
                    a.data<InType>(),
                    b.data<InType>(),
                    out_a.data<OutType>(),
                    out_b.data<OutType>(),
                    out_a.size(),
                    const_param(shape),
                    const_param(a_strides),
                    const_param(b_strides),
                    ndim);
              }
            });
          } else {
            MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, {
              using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
              auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
              if (bopt == BinaryOpType::ScalarVector) {
                kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
              } else if (bopt == BinaryOpType::VectorScalar) {
                kernel = cu::binary_vs<Op, InType, OutType, IdxT>;
              } else if (bopt == BinaryOpType::VectorVector) {
                kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
              }
              auto [num_blocks, block_dims] = get_launch_args(
                  kernel,
                  out_a.data_size(),
                  out_a.shape(),
                  out_a.strides(),
                  LARGE);
              kernel<<<num_blocks, block_dims, 0, stream>>>(
                  a.data<InType>(),
                  b.data<InType>(),
                  out_a.data<OutType>(),
                  out_b.data<OutType>(),
                  out_a.data_size());
            });
          }
        } else {
          throw std::runtime_error(fmt::format(
              "Can not do binary op {} on inputs of {} with result of {}.",
              op,
              dtype_to_string(a.dtype()),
              dtype_to_string(out_a.dtype())));
        }
      });
    });
  });
}

template <typename Op>
void binary_op_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    std::string_view op,
    const Stream& s) {
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, outputs[0], bopt);
  set_binary_op_output_data(a, b, outputs[1], bopt);
  binary_op_gpu_inplace<Op>(inputs, outputs, op, s);
}

void DivMod::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("DivMod::eval_gpu");
  auto& s = outputs[0].primitive().stream();
  binary_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
}

} // namespace mlx::core