// Copyright © 2025 Apple Inc.

// Ternary-op (select) GPU backend: dispatches elementwise `out = op(a, b, c)`
// where `a` is a bool condition array and `b`/`c` share the output dtype.
//
// NOTE(review): the original file reached us with every angle-bracket span
// stripped by a text-extraction step (template headers, #include targets,
// <<<...>>> launch configs, template arguments). They are restored below from
// the surviving call sites; confirm against the upstream MLX sources.

#include "mlx/backend/common/ternary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/ternary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

// Restored: `cg = cooperative_groups` and `nvtx3::scoped_range` below pin
// these two stripped system includes.
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

// Fully contiguous case: one thread per element, flat indexing.
// `a` is the bool selector; `b`/`c` are the two value operands.
// Grid is sized by the host so one thread covers each of `size` elements;
// the `index < size` guard handles the ragged tail block.
template <typename Op, typename T, typename IdxT>
__global__ void
ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    out[index] = Op{}(a[index], b[index], c[index]);
  }
}

// General (strided) case with a compile-time dimension count NDIM (<= 3 at
// the call site). Shapes/strides are passed by value as __grid_constant__
// arrays so the kernel reads them from constant parameter space instead of
// global memory.
template <typename Op, typename T, typename IdxT, int NDIM>
__global__ void ternary_g_nd(
    const bool* a,
    const T* b,
    const T* c,
    T* out,
    IdxT size,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> c_strides) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    // Map the flat output index to per-operand strided locations.
    auto [a_idx, b_idx, c_idx] = elem_to_loc_nd<NDIM>(
        index,
        shape.data(),
        a_strides.data(),
        b_strides.data(),
        c_strides.data());
    out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
  }
}

// General (strided) case with a runtime dimension count, used when the
// collapsed shape has more than 3 dims.
template <typename Op, typename T, typename IdxT>
__global__ void ternary_g(
    const bool* a,
    const T* b,
    const T* c,
    T* out,
    IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides a_strides,
    const __grid_constant__ Strides b_strides,
    const __grid_constant__ Strides c_strides,
    int ndim) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [a_idx, b_idx, c_idx] = elem_to_loc_4d(
        index,
        shape.data(),
        a_strides.data(),
        b_strides.data(),
        c_strides.data(),
        ndim);
    out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
  }
}

} // namespace cu

// Launch the ternary kernel for `out` (already allocated). Selects between
// the contiguous fast path and the general strided path, and between 32-bit
// and 64-bit indexing depending on operand sizes.
template <typename Op>
void ternary_op_gpu_inplace(
    const std::vector<array>& inputs,
    array& out,
    const Stream& s) {
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  const auto& c = inputs[2];
  if (out.size() == 0) {
    return;
  }

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_input_array(c);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    dispatch_all_types(out.dtype(), [&](auto type_tag) {
      using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
      auto topt = get_ternary_op_type(a, b, c);
      if (topt == TernaryOpType::General) {
        // 64-bit indexing only when any operand exceeds int32 range.
        dispatch_bool(
            a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
                c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
            [&](auto large) {
              using IdxT = std::conditional_t<large(), int64_t, int32_t>;
              // Collapse contiguous dims to minimize the indexing work.
              Shape shape;
              std::vector<Strides> strides;
              std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out);
              auto& a_strides = strides[0];
              auto& b_strides = strides[1];
              auto& c_strides = strides[2];
              int ndim = shape.size();
              if (ndim <= 3) {
                // Compile-time NDIM specialization for small ranks.
                dispatch_1_2_3(ndim, [&](auto dims_constant) {
                  auto kernel =
                      cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>;
                  auto [num_blocks, block_dims] =
                      get_launch_args(kernel, out, large());
                  kernel<<<num_blocks, block_dims, 0, stream>>>(
                      a.data<bool>(),
                      b.data<DType>(),
                      c.data<DType>(),
                      out.data<DType>(),
                      out.size(),
                      const_param<dims_constant()>(shape),
                      const_param<dims_constant()>(a_strides),
                      const_param<dims_constant()>(b_strides),
                      const_param<dims_constant()>(c_strides));
                });
              } else {
                auto kernel = cu::ternary_g<Op, DType, IdxT>;
                auto [num_blocks, block_dims] =
                    get_launch_args(kernel, out, large());
                kernel<<<num_blocks, block_dims, 0, stream>>>(
                    a.data<bool>(),
                    b.data<DType>(),
                    c.data<DType>(),
                    out.data<DType>(),
                    out.data_size(),
                    const_param(shape),
                    const_param(a_strides),
                    const_param(b_strides),
                    const_param(c_strides),
                    ndim);
              }
            });
      } else {
        // Contiguous fast path: flat one-thread-per-element kernel.
        dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
          using IdxT = std::conditional_t<large(), int64_t, int32_t>;
          auto kernel = cu::ternary_v<Op, DType, IdxT>;
          auto [num_blocks, block_dims] = get_launch_args(
              kernel, out.data_size(), out.shape(), out.strides(), large());
          kernel<<<num_blocks, block_dims, 0, stream>>>(
              a.data<bool>(),
              b.data<DType>(),
              c.data<DType>(),
              out.data<DType>(),
              out.data_size());
        });
      }
    });
  });
}

// Allocate/donate the output buffer for the ternary op, then run the
// in-place launcher above.
template <typename Op>
void ternary_op_gpu(
    const std::vector<array>& inputs,
    array& out,
    const Stream& s) {
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto& c = inputs[2];
  auto topt = get_ternary_op_type(a, b, c);
  set_ternary_op_output_data(a, b, c, out, topt);
  ternary_op_gpu_inplace<Op>(inputs, out, s);
}

// `where(condition, x, y)` primitive.
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("select::eval_gpu");
  auto& s = out.primitive().stream();
  ternary_op_gpu<cu::Select>(inputs, out, s);
}

} // namespace mlx::core