// Copyright © 2025 Apple Inc.

// This file includes host-only utilities for writing CUDA kernels; the
// difference from backend/cuda/device/utils.cuh is that the latter file only
// includes device-only code.

#pragma once

#include <type_traits>

#include "mlx/array.h"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cuda.h>
#include <cuda/std/array>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <fmt/format.h>

namespace mlx::core {

template <typename F>
void dispatch_1_2_3(int n, F&& f) {
  switch (n) {
    case 1:
      f(std::integral_constant<int, 1>{});
      break;
    case 2:
      f(std::integral_constant<int, 2>{});
      break;
    case 3:
      f(std::integral_constant<int, 3>{});
      break;
  }
}

template <typename F>
void dispatch_bool(bool v, F&& f) {
  if (v) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}

template <typename F>
void dispatch_block_dim(int threads, F&& f) {
  if (threads <= WARP_SIZE) {
    f(std::integral_constant<int, WARP_SIZE>{});
  } else if (threads <= WARP_SIZE * 2) {
    f(std::integral_constant<int, WARP_SIZE * 2>{});
  } else if (threads <= WARP_SIZE * 4) {
    f(std::integral_constant<int, WARP_SIZE * 4>{});
  } else if (threads <= WARP_SIZE * 8) {
    f(std::integral_constant<int, WARP_SIZE * 8>{});
  } else if (threads <= WARP_SIZE * 16) {
    f(std::integral_constant<int, WARP_SIZE * 16>{});
  } else {
    f(std::integral_constant<int, WARP_SIZE * 32>{});
  }
}

// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {
  using type = T;
};

template <>
struct CTypeToCudaType<float16_t> {
  using type = __half;
};

template <>
struct CTypeToCudaType<bfloat16_t> {
  using type = __nv_bfloat16;
};

template <>
struct CTypeToCudaType<complex64_t> {
  using type = cu::complex64_t;
};

template <typename T>
using cuda_type_t = typename CTypeToCudaType<T>::type;

// Type traits for detecting floating point numbers.
template <typename T>
inline constexpr bool is_floating_v =
    cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
    cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;

// Type traits for detecting complex numbers.
template <typename T>
inline constexpr bool is_complex_v = cuda::std::is_same_v<T, complex64_t> ||
    cuda::std::is_same_v<T, complex128_t>;

// Type traits for detecting complex or real floating point numbers.
template <typename T>
inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;

// Utility to copy data from vector to array in host.
template <size_t NDIM, typename T>
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
  if (vec.size() > NDIM) {
    throw std::runtime_error(
        fmt::format("ndim can not be larger than {}.", NDIM));
  }
  cuda::std::array<T, NDIM> result;
  std::copy_n(vec.begin(), vec.size(), result.begin());
  return result;
}

// Compute the grid and block dimensions, check backend/common/utils.h for
// docs.
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
dim3 get_2d_grid_dims(
    const Shape& shape,
    const Strides& strides,
    size_t divisor);
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);

// Return a block size that achieves maximum potential occupancy for kernel.
template <typename T>
inline uint max_occupancy_block_dim(T kernel) {
  int _, block_dim;
  if constexpr (std::is_same_v<T, CUfunction>) {
    CHECK_CUDA_ERROR(
        cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
  } else {
    CHECK_CUDA_ERROR(
        cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
  }
  return block_dim;
}
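
// Example usage of the dispatch helpers above (a minimal sketch; `ndim`,
// `contiguous`, and `cu::copy_kernel` are hypothetical names, not part of
// this header). Each helper maps a runtime value to a compile-time constant,
// so a generic lambda can instantiate one kernel template per case:
//
//   dispatch_bool(contiguous, [&](auto contiguous_tag) {
//     dispatch_1_2_3(ndim, [&](auto ndim_tag) {
//       // The tags are std::integral_constant values, so calling them
//       // yields constexpr values usable as template arguments.
//       auto kernel = cu::copy_kernel<contiguous_tag(), ndim_tag()>;
//       ...
//     });
//   });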
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
// assuming each thread handles |work_per_thread| elements of |arr|.
template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
    T kernel,
    size_t size,
    const Shape& shape,
    const Strides& strides,
    bool large,
    int work_per_thread = 1) {
  size_t nthreads = cuda::ceil_div(size, work_per_thread);
  uint block_dim = max_occupancy_block_dim(kernel);
  if (block_dim > nthreads) {
    block_dim = nthreads;
  }
  dim3 num_blocks;
  if (large) {
    num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
    num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
  } else {
    num_blocks.x = cuda::ceil_div(nthreads, block_dim);
  }
  return std::make_tuple(num_blocks, block_dim);
}

template <typename T>
inline std::tuple<dim3, uint> get_launch_args(
    T kernel, const array& arr, bool large, int work_per_thread = 1) {
  return get_launch_args(
      kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
}

} // namespace mlx::core
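
// Example launch sequence (a minimal sketch; `cu::unary_kernel` and the
// `large` heuristic are assumptions for illustration, not part of this
// header). get_launch_args() picks the occupancy-maximizing block size and
// derives the grid from the output array; `large` selects the 2D grid path
// for arrays whose extents may overflow a 32-bit grid index:
//
//   auto kernel = cu::unary_kernel<float>;
//   bool large = out.data_size() > UINT32_MAX;
//   auto [num_blocks, block_dims] =
//       mlx::core::get_launch_args(kernel, out, large);
//   kernel<<<num_blocks, block_dims>>>(
//       in.data<float>(), out.data<float>(), out.data_size());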