// Copyright © 2025 Apple Inc.

// This file includes host-only utilities for writing CUDA kernels; the
// difference from backend/cuda/device/utils.cuh is that the latter file
// includes only device-only code.

#pragma once

#include <type_traits>

#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda/std/array>
#include <fmt/format.h>

namespace mlx::core {

// Invokes |f| with |n| lifted to a compile-time std::integral_constant;
// |n| must be 1, 2 or 3 (other values are a no-op).
template <typename F>
void dispatch_1_2_3(int n, F&& f) {
  switch (n) {
    case 1:
      f(std::integral_constant<int, 1>{});
      break;
    case 2:
      f(std::integral_constant<int, 2>{});
      break;
    case 3:
      f(std::integral_constant<int, 3>{});
      break;
  }
}

// Invokes |f| with std::true_type or std::false_type depending on |v|.
template <typename F>
void dispatch_bool(bool v, F&& f) {
  if (v) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}

// Invokes |f| with the smallest power-of-two multiple of WARP_SIZE that
// covers |threads| (capped at WARP_SIZE * 32), as a compile-time constant.
template <typename F>
void dispatch_block_dim(int threads, F&& f) {
  if (threads <= WARP_SIZE) {
    f(std::integral_constant<int, WARP_SIZE>{});
  } else if (threads <= WARP_SIZE * 2) {
    f(std::integral_constant<int, WARP_SIZE * 2>{});
  } else if (threads <= WARP_SIZE * 4) {
    f(std::integral_constant<int, WARP_SIZE * 4>{});
  } else if (threads <= WARP_SIZE * 8) {
    f(std::integral_constant<int, WARP_SIZE * 8>{});
  } else if (threads <= WARP_SIZE * 16) {
    f(std::integral_constant<int, WARP_SIZE * 16>{});
  } else {
    f(std::integral_constant<int, WARP_SIZE * 32>{});
  }
}

// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {
  using type = T;
};

template <>
struct CTypeToCudaType<float16_t> {
  using type = __half;
};

template <>
struct CTypeToCudaType<bfloat16_t> {
  using type = __nv_bfloat16;
};

template <>
struct CTypeToCudaType<complex64_t> {
  using type = cu::complex64_t;
};

template <typename T>
using cuda_type_t = typename CTypeToCudaType<T>::type;

// Type traits for detecting floating numbers.
template <typename T>
inline constexpr bool is_floating_v =
    cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
    cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;

// Type traits for detecting complex numbers.
template <typename T>
inline constexpr bool is_complex_v = cuda::std::is_same_v<T, complex64_t> ||
    cuda::std::is_same_v<T, complex128_t>;

// Type traits for detecting complex or real floating point numbers.
template <typename T>
inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;

// Utility to copy data from a vector into a fixed-size array on the host.
template <size_t NDIM = MAX_NDIM, typename T>
inline cuda::std::array<T, NDIM> const_param(const SmallVector<T>& vec) {
  if (vec.size() > NDIM) {
    throw std::runtime_error(
        fmt::format("ndim can not be larger than {}.", NDIM));
  }
  cuda::std::array<T, NDIM> result;
  std::copy_n(vec.begin(), vec.size(), result.begin());
  return result;
}

// Compute the grid and block dimensions, see backend/common/utils.h for docs.
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
dim3 get_2d_grid_dims(
    const Shape& shape,
    const Strides& strides,
    size_t divisor);
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);

// Get the num_blocks and block_dims assuming each thread handles
// |work_per_thread| elements of |arr|.
std::tuple<dim3, uint> get_launch_args(
    size_t size,
    const Shape& shape,
    const Strides& strides,
    bool large,
    int work_per_thread = 1,
    uint max_block_dim = 1024);

inline std::tuple<dim3, uint> get_launch_args(
    const array& arr,
    bool large,
    int work_per_thread = 1,
    uint max_block_dim = 1024) {
  return get_launch_args(
      arr.size(),
      arr.shape(),
      arr.strides(),
      large,
      work_per_thread,
      max_block_dim);
}

} // namespace mlx::core
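
// ---------------------------------------------------------------------------
// Usage sketch (illustrative, not part of the original header): the dispatch
// helpers lift a runtime value into a compile-time constant so one call site
// can instantiate a templated kernel per case. `copy_nd_kernel` and
// `launch_copy_nd` are hypothetical names used only for this example.
//
//   template <int NDIM>
//   __global__ void copy_nd_kernel(/* ... */);
//
//   void launch_copy_nd(const mlx::core::array& arr) {
//     using namespace mlx::core;
//     dispatch_1_2_3(arr.ndim(), [&](auto ndim) {
//       // decltype(ndim)::value recovers the integral_constant's value.
//       constexpr int NDIM = decltype(ndim)::value;
//       auto [num_blocks, block_dims] = get_launch_args(arr, /* large */ false);
//       copy_nd_kernel<NDIM><<<num_blocks, block_dims>>>(/* ... */);
//     });
//   }
// ---------------------------------------------------------------------------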
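
// ---------------------------------------------------------------------------
// Usage sketch for dispatch_block_dim (illustrative; `row_reduce_kernel` is a
// hypothetical name). Because the chosen block size becomes a compile-time
// constant, the kernel can size shared memory and unroll loops on it.
//
//   mlx::core::dispatch_block_dim(threads_needed, [&](auto block_dim) {
//     constexpr int BLOCK_DIM = decltype(block_dim)::value;
//     row_reduce_kernel<BLOCK_DIM><<<num_rows, BLOCK_DIM>>>(/* ... */);
//   });
// ---------------------------------------------------------------------------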
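
// ---------------------------------------------------------------------------
// Usage sketch for cuda_type_t and const_param (illustrative; `unary_kernel`
// and `launch_unary` are hypothetical names). cuda_type_t converts host-side
// element types to their CUDA equivalents (e.g. float16_t -> __half), and
// const_param packs a shape/strides vector into a fixed-size
// cuda::std::array that a kernel can take by value.
//
//   template <typename T>
//   void launch_unary(const mlx::core::array& in, mlx::core::array& out) {
//     using namespace mlx::core;
//     using CudaT = cuda_type_t<T>;
//     auto shape = const_param(in.shape());
//     auto strides = const_param(in.strides());
//     auto [num_blocks, block_dims] = get_launch_args(out, /* large */ false);
//     unary_kernel<CudaT><<<num_blocks, block_dims>>>(
//         in.data<CudaT>(), out.data<CudaT>(), shape, strides, in.ndim());
//   }
// ---------------------------------------------------------------------------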