diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 9b12d84a9..1567feafd 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -34,6 +34,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
diff --git a/mlx/backend/cuda/device/ternary_ops.cuh b/mlx/backend/cuda/device/ternary_ops.cuh
new file mode 100644
index 000000000..d1d008ac5
--- /dev/null
+++ b/mlx/backend/cuda/device/ternary_ops.cuh
@@ -0,0 +1,12 @@
+// Copyright © 2025 Apple Inc.
+
+namespace mlx::core::cu {
+
+struct Select {
+  template <typename T>
+  __device__ T operator()(bool condition, T x, T y) {
+    return condition ? x : y;
+  }
+};
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh
index a1d387201..6f9851c94 100644
--- a/mlx/backend/cuda/device/utils.cuh
+++ b/mlx/backend/cuda/device/utils.cuh
@@ -162,6 +162,27 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
   return cuda::std::make_tuple(a_loc, b_loc);
 }
 
+template <int NDIM, typename IdxT = int32_t>
+inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
+    IdxT elem,
+    const int* shape,
+    const int64_t* a_strides,
+    const int64_t* b_strides,
+    const int64_t* c_strides) {
+  IdxT a_loc = 0;
+  IdxT b_loc = 0;
+  IdxT c_loc = 0;
+#pragma unroll
+  for (int i = NDIM - 1; i >= 0; --i) {
+    int dim_idx = elem % shape[i];
+    a_loc += dim_idx * a_strides[i];
+    b_loc += dim_idx * b_strides[i];
+    c_loc += dim_idx * c_strides[i];
+    elem /= shape[i];
+  }
+  return cuda::std::make_tuple(a_loc, b_loc, c_loc);
+}
+
 // Optimized version when ndim is larger than 4.
 template <typename IdxT = int32_t>
 inline __host__ __device__ IdxT
@@ -191,6 +212,26 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
   return cuda::std::make_tuple(a_loc, b_loc);
 }
 
+template <typename IdxT = int32_t>
+inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
+    IdxT elem,
+    const int* shape,
+    const int64_t* a_strides,
+    const int64_t* b_strides,
+    const int64_t* c_strides,
+    int ndim) {
+  auto [a_loc, b_loc, c_loc] =
+      elem_to_loc_nd<3>(elem, shape, a_strides, b_strides, c_strides);
+  for (int i = ndim - 1; i >= 3; --i) {
+    int dim_idx = elem % shape[i];
+    a_loc += dim_idx * a_strides[i];
+    b_loc += dim_idx * b_strides[i];
+    c_loc += dim_idx * c_strides[i];
+    elem /= shape[i];
+  }
+  return cuda::std::make_tuple(a_loc, b_loc, c_loc);
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // Elem to loc in a loop utils
 ///////////////////////////////////////////////////////////////////////////////
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index 95ea44f94..eb451f49d 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -91,7 +91,6 @@ NO_GPU(QuantizedMatmul)
 NO_GPU(Scan)
 NO_GPU(Scatter)
 NO_GPU(ScatterAxis)
-NO_GPU(Select)
 NO_GPU_MULTI(SVD)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu
new file mode 100644
index 000000000..0a2c67f76
--- /dev/null
+++ b/mlx/backend/cuda/ternary.cu
@@ -0,0 +1,181 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/ternary.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/device/ternary_ops.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <cooperative_groups.h>
+#include <nvtx3/nvtx3.hpp>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename Op, typename T, typename IdxT>
+__global__ void
+ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    out[index] = Op{}(a[index], b[index], c[index]);
+  }
+}
+
+template <typename Op, typename T, typename IdxT, int NDIM>
+__global__ void ternary_g_nd(
+    const bool* a,
+    const T* b,
+    const T* c,
+    T* out,
+    IdxT size,
+    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> c_strides) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    auto [a_idx, b_idx, c_idx] = elem_to_loc_nd<NDIM>(
+        index,
+        shape.data(),
+        a_strides.data(),
+        b_strides.data(),
+        c_strides.data());
+    out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
+  }
+}
+
+template <typename Op, typename T, typename IdxT>
+__global__ void ternary_g(
+    const bool* a,
+    const T* b,
+    const T* c,
+    T* out,
+    IdxT size,
+    const __grid_constant__ Shape shape,
+    const __grid_constant__ Strides a_strides,
+    const __grid_constant__ Strides b_strides,
+    const __grid_constant__ Strides c_strides,
+    int ndim) {
+  IdxT index = cg::this_grid().thread_rank();
+  if (index < size) {
+    auto [a_idx, b_idx, c_idx] = elem_to_loc_4d(
+        index,
+        shape.data(),
+        a_strides.data(),
+        b_strides.data(),
+        c_strides.data(),
+        ndim);
+    out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]);
+  }
+}
+
+} // namespace cu
+
+template <typename Op>
+void ternary_op_gpu_inplace(
+    const std::vector<array>& inputs,
+    array& out,
+    std::string_view op,
+    const Stream& s) {
+  assert(inputs.size() == 3);
+  const auto& a = inputs[0];
+  const auto& b = inputs[1];
+  const auto& c = inputs[2];
+  if (out.size() == 0) {
+    return;
+  }
+
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(c);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
+      using DType = cuda_type_t<CTYPE>;
+
+      auto topt = get_ternary_op_type(a, b, c);
+      if (topt == TernaryOpType::General) {
+        auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
+        auto& a_strides = strides[0];
+        auto& b_strides = strides[1];
+        auto& c_strides = strides[2];
+        bool large = a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
+            c.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
+        MLX_SWITCH_BOOL(large, LARGE, {
+          using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+          int ndim = shape.size();
+          if (ndim <= 3) {
+            MLX_SWITCH_1_2_3(ndim, NDIM, {
+              auto kernel = &cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
+              auto [num_blocks, block_dims] =
+                  get_launch_args(kernel, out, large);
+              kernel<<<num_blocks, block_dims, 0, stream>>>(
+                  a.data<bool>(),
+                  b.data<DType>(),
+                  c.data<DType>(),
+                  out.data<DType>(),
+                  out.data_size(),
+                  const_param<NDIM>(shape),
+                  const_param<NDIM>(a_strides),
+                  const_param<NDIM>(b_strides),
+                  const_param<NDIM>(c_strides));
+            });
+          } else {
+            auto kernel = cu::ternary_g<Op, DType, IdxT>;
+            auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+            kernel<<<num_blocks, block_dims, 0, stream>>>(
+                a.data<bool>(),
+                b.data<DType>(),
+                c.data<DType>(),
+                out.data<DType>(),
+                out.data_size(),
+                const_param(shape),
+                const_param(a_strides),
+                const_param(b_strides),
+                const_param(c_strides),
+                ndim);
+          }
+        });
+      } else {
+        MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
+          using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+          auto kernel = cu::ternary_v<Op, DType, IdxT>;
+          auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
+          kernel<<<num_blocks, block_dims, 0, stream>>>(
+              a.data<bool>(),
+              b.data<DType>(),
+              c.data<DType>(),
+              out.data<DType>(),
+              out.data_size());
+        });
+      }
+    });
+  });
+}
+
+template <typename Op>
+void ternary_op_gpu(
+    const std::vector<array>& inputs,
+    array& out,
+    std::string_view op,
+    const Stream& s) {
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto& c = inputs[2];
+  auto topt = get_ternary_op_type(a, b, c);
+  set_ternary_op_output_data(a, b, c, out, topt);
+  ternary_op_gpu_inplace<Op>(inputs, out, op, s);
+}
+
+void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("select::eval_gpu");
+  auto& s = out.primitive().stream();
+  ternary_op_gpu<cu::Select>(inputs, out, get_primitive_string(this), s);
+}
+
+} // namespace mlx::core