// Copyright © 2025 Apple Inc.

// This file must not include any host-only code; utilities that work under
// both host and device can be put here.
//
// See more about the requirements at:
// https://docs.nvidia.com/cuda/nvrtc/#language

#pragma once

#include "mlx/backend/cuda/device/complex.cuh"
#include "mlx/backend/cuda/device/config.h"

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda/std/array>
#include <cuda/std/limits>
#include <cuda/std/tuple>

namespace mlx::core::cu {

///////////////////////////////////////////////////////////////////////////////
// CUDA kernel utils
///////////////////////////////////////////////////////////////////////////////

// To pass shape/strides to kernels via constant memory, their size must be
// known at compile time.
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
using Strides = cuda::std::array<int64_t, MAX_NDIM>;

// Vectorized load/store.
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector {
  T val[N];

  __device__ T& operator[](int i) {
    return val[i];
  }

  __device__ T operator[](int i) const {
    return val[i];
  }
};

template <int N, typename T>
inline __host__ __device__ bool is_aligned(T* x) {
  return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
}

template <int N, typename T>
inline __device__ AlignedVector<T, N> unsafe_load_vector(
    const T* ptr,
    uint32_t offset) {
  auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
  return from[offset];
}

template <int N, typename T>
inline __device__ AlignedVector<T, N> load_vector(
    const T* ptr,
    uint32_t offset) {
  if (is_aligned<N>(ptr)) {
    auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
    return from[offset];
  } else {
    AlignedVector<T, N> v;
#pragma unroll
    for (int i = 0; i < N; ++i) {
      v[i] = ptr[offset * N + i];
    }
    return v;
  }
}

template <int N, typename T, typename SizeT>
inline __device__ AlignedVector<T, N>
load_vector(const T* ptr, uint32_t offset, SizeT size, T fallback) {
  if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
    auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
    return from[offset];
  } else {
    AlignedVector<T, N> v;
#pragma unroll
    for (int i = 0; i < N; ++i) {
      v[i] = (N * offset + i) < size ? ptr[offset * N + i] : fallback;
    }
    return v;
  }
}

template <int N, typename T, typename SizeT>
inline __device__ AlignedVector<T, N> load_vector(
    const T* ptr,
    uint32_t offset,
    SizeT size,
    int64_t stride,
    T fallback) {
  if (is_aligned<N>(ptr) && stride == 1 && (offset + 1) * N <= size) {
    auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
    return from[offset];
  } else {
    AlignedVector<T, N> v;
#pragma unroll
    for (int i = 0; i < N; ++i) {
      v[i] = (N * offset + i) < size ? ptr[stride * (offset * N + i)] : fallback;
    }
    return v;
  }
}
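// Illustrative sketch (not part of this header): how a kernel might use the
// vectorized load/store helpers in this file. Each thread handles N contiguous
// elements; the kernel name, launch shape, and `scale` parameter are
// hypothetical.
//
//   template <typename T, int N>
//   __global__ void scale_kernel(const T* in, T* out, T scale, uint32_t size) {
//     uint32_t index = blockIdx.x * blockDim.x + threadIdx.x;
//     // Bounds-checked vectorized load; out-of-range lanes get the fallback.
//     auto vec = load_vector<N>(in, index, size, T(0));
// #pragma unroll
//     for (int i = 0; i < N; ++i) {
//       vec[i] = scale * vec[i];
//     }
//     // Bounds-checked vectorized store of the N results.
//     store_vector<N>(out, index, vec, size);
//   }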
template <int N, typename T>
inline __device__ void
unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
  auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
  to[offset] = vec;
}

template <int N, typename T>
inline __device__ void
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
  if (is_aligned<N>(ptr)) {
    auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
    to[offset] = vec;
  } else {
#pragma unroll
    for (int i = 0; i < N; ++i) {
      ptr[offset * N + i] = vec[i];
    }
  }
}

template <int N, typename T, typename SizeT>
inline __device__ void store_vector(
    T* ptr,
    uint32_t offset,
    const AlignedVector<T, N>& vec,
    SizeT size) {
  if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
    auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
    to[offset] = vec;
  } else {
    for (int i = 0; (offset * N + i) < size && i < N; ++i) {
      ptr[offset * N + i] = vec[i];
    }
  }
}

template <int N, typename T, typename SizeT>
inline __device__ void store_vector(
    T* ptr,
    uint32_t offset,
    const AlignedVector<T, N>& vec,
    SizeT size,
    int64_t stride) {
  if (is_aligned<N>(ptr) && (offset + 1) * N <= size && stride == 1) {
    auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
    to[offset] = vec;
  } else {
    for (int i = 0; (offset * N + i) < size && i < N; ++i) {
      ptr[stride * (offset * N + i)] = vec[i];
    }
  }
}

///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename = void>
struct Limits {
  static constexpr __host__ __device__ T max() {
    return cuda::std::numeric_limits<T>::max();
  }
  static constexpr __host__ __device__ T min() {
    return cuda::std::numeric_limits<T>::min();
  }
  static constexpr __host__ __device__ T finite_max() {
    return cuda::std::numeric_limits<T>::max();
  }
  static constexpr __host__ __device__ T finite_min() {
    return cuda::std::numeric_limits<T>::min();
  }
};

template <typename T>
struct Limits<
    T,
    cuda::std::enable_if_t<
        cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double>>> {
  static constexpr __host__ __device__ T max() {
    return cuda::std::numeric_limits<T>::infinity();
  }
  static constexpr __host__ __device__ T min() {
    return -cuda::std::numeric_limits<T>::infinity();
  }
  static constexpr __host__ __device__ T finite_max() {
    return cuda::std::numeric_limits<T>::max();
  }
  static constexpr __host__ __device__ T finite_min() {
    return cuda::std::numeric_limits<T>::lowest();
  }
};
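// Illustrative sketch (not part of this header): Limits<T> is the kind of
// helper a reduction initializer would use, e.g. seeding a running maximum
// with the smallest representable value:
//
//   float running_max = Limits<float>::min();          // -infinity
//   float lowest_finite = Limits<float>::finite_min();  // most negative finite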
// CUDA 11 does not have host side arithmetic operators for half types.
template <typename T>
struct Limits<
    T,
    cuda::std::enable_if_t<
        cuda::std::is_same_v<T, __half> ||
        cuda::std::is_same_v<T, __nv_bfloat16>>> {
  static constexpr __host__ __device__ T max() {
    return cuda::std::numeric_limits<T>::infinity();
  }
  static constexpr __host__ __device__ T min() {
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
    return -cuda::std::numeric_limits<float>::infinity();
#else
    return -cuda::std::numeric_limits<T>::infinity();
#endif
  }
  static constexpr __host__ __device__ T finite_max() {
    return cuda::std::numeric_limits<T>::max();
  }
  static constexpr __host__ __device__ T finite_min() {
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
    return cuda::std::numeric_limits<float>::lowest();
#else
    return cuda::std::numeric_limits<T>::lowest();
#endif
  }
};

template <>
struct Limits<bool> {
  static constexpr __host__ __device__ bool max() {
    return true;
  }
  static constexpr __host__ __device__ bool min() {
    return false;
  }
};

template <typename T>
struct Limits<complex_t<T>> {
  static constexpr __host__ __device__ complex_t<T> max() {
    return {Limits<T>::max(), Limits<T>::max()};
  }
  static constexpr __host__ __device__ complex_t<T> min() {
    return {Limits<T>::min(), Limits<T>::min()};
  }
};

///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////

template <typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
  IdxT loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * IdxT(strides[i]);
    elem /= shape[i];
  }
  return loc;
}

// Optimize when the ndim is known at compile time.
template <int NDIM, typename IdxT = int64_t>
inline __host__ __device__ IdxT
elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) {
  IdxT loc = 0;
#pragma unroll
  for (int i = NDIM - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * IdxT(strides[i]);
    elem /= shape[i];
  }
  return loc;
}

template <int NDIM, typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
    IdxT elem,
    const int* shape,
    const int64_t* a_strides,
    const int64_t* b_strides) {
  IdxT a_loc = 0;
  IdxT b_loc = 0;
#pragma unroll
  for (int i = NDIM - 1; i >= 0; --i) {
    int dim_idx = elem % shape[i];
    a_loc += dim_idx * IdxT(a_strides[i]);
    b_loc += dim_idx * IdxT(b_strides[i]);
    elem /= shape[i];
  }
  return cuda::std::make_tuple(a_loc, b_loc);
}

template <int NDIM, typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
    IdxT elem,
    const int* shape,
    const int64_t* a_strides,
    const int64_t* b_strides,
    const int64_t* c_strides) {
  IdxT a_loc = 0;
  IdxT b_loc = 0;
  IdxT c_loc = 0;
#pragma unroll
  for (int i = NDIM - 1; i >= 0; --i) {
    int dim_idx = elem % shape[i];
    a_loc += dim_idx * IdxT(a_strides[i]);
    b_loc += dim_idx * IdxT(b_strides[i]);
    c_loc += dim_idx * IdxT(c_strides[i]);
    elem /= shape[i];
  }
  return cuda::std::make_tuple(a_loc, b_loc, c_loc);
}

template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc(
    IdxT elem,
    const int* shape,
    const int64_t* a_strides,
    const int64_t* b_strides,
    int ndim) {
  IdxT a_loc = 0;
  IdxT b_loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    int dim_idx = elem % shape[i];
    a_loc += dim_idx * IdxT(a_strides[i]);
    b_loc += dim_idx * IdxT(b_strides[i]);
    elem /= shape[i];
  }
  return cuda::std::make_tuple(a_loc, b_loc);
}

template <typename IdxT = int64_t>
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc(
    IdxT elem,
    const int* shape,
    const int64_t* a_strides,
    const int64_t* b_strides,
    const int64_t* c_strides,
    int ndim) {
  IdxT a_loc = 0;
  IdxT b_loc = 0;
  IdxT c_loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    int dim_idx = elem % shape[i];
    a_loc += dim_idx * IdxT(a_strides[i]);
    b_loc += dim_idx * IdxT(b_strides[i]);
    c_loc += dim_idx * IdxT(c_strides[i]);
    elem /= shape[i];
  }
  return cuda::std::make_tuple(a_loc, b_loc, c_loc);
}
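// Illustrative sketch (not part of this header): elem_to_loc maps a row-major
// linear element index into a memory offset for a strided (possibly
// non-contiguous) array. For example, a transposed 3x2 view of a contiguous
// 2x3 buffer has shape = {3, 2} and strides = {1, 3}; element 4 of the view
// (row 2, column 0) resolves to offset 2 * 1 + 0 * 3 = 2:
//
//   int shape[2] = {3, 2};
//   int64_t strides[2] = {1, 3};
//   int64_t loc = elem_to_loc<int64_t>(4, shape, strides, 2);  // loc == 2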
///////////////////////////////////////////////////////////////////////////////
// Elem to loc in a loop utils
///////////////////////////////////////////////////////////////////////////////

template <int DIM, bool General = true, typename OffsetT = size_t>
struct LoopedElemToLoc {
  int dim;
  LoopedElemToLoc<DIM - 1, General, OffsetT> inner_looper;
  OffsetT offset{0};
  int index{0};

  __device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}

  __device__ void next(const int* shape, const int64_t* strides) {
    if (dim == 0) {
      return;
    }
    index++;
    offset += OffsetT(strides[dim - 1]);
    if (index >= shape[dim - 1]) {
      index = 0;
      inner_looper.next(shape, strides);
      offset = inner_looper.offset;
    }
  }

  __device__ void next(int n, const int* shape, const int64_t* strides) {
    if (dim == 0) {
      return;
    }
    index += n;
    offset += n * OffsetT(strides[dim - 1]);
    if (index >= shape[dim - 1]) {
      int extra = index - shape[dim - 1];
      if (extra >= shape[dim - 1]) {
        inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
        extra = extra % shape[dim - 1];
      } else {
        inner_looper.next(shape, strides);
      }
      index = 0;
      offset = inner_looper.offset;
      if (extra > 0) {
        next(extra, shape, strides);
      }
    }
  }

  __device__ OffsetT location() {
    return offset;
  }
};

template <typename OffsetT>
struct LoopedElemToLoc<1, true, OffsetT> {
  int dim;
  OffsetT offset{0};
  int index{0};

  __device__ LoopedElemToLoc(int dim) : dim(dim) {}

  __device__ void next(const int* shape, const int64_t* strides) {
    index++;
    if (dim > 1) {
      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
    } else {
      offset += OffsetT(strides[0]);
    }
  }

  __device__ void next(int n, const int* shape, const int64_t* strides) {
    index += n;
    if (dim > 1) {
      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
    } else {
      offset = index * OffsetT(strides[0]);
    }
  }

  __device__ OffsetT location() {
    return offset;
  }
};

template <typename OffsetT>
struct LoopedElemToLoc<1, false, OffsetT> {
  OffsetT offset{0};

  __device__ LoopedElemToLoc(int) {}

  __device__ void next(const int*, const int64_t* strides) {
    offset += OffsetT(strides[0]);
  }

  __device__ void next(int n, const int*, const int64_t* strides) {
    offset += n * OffsetT(strides[0]);
  }

  __device__ OffsetT location() {
    return offset;
  }
};

} // namespace mlx::core::cu
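// Illustrative sketch (not part of this header): LoopedElemToLoc incrementally
// tracks the memory offset of a strided array while a kernel walks elements in
// row-major order, avoiding a full div/mod chain on every step. A hypothetical
// device-side copy loop over an array with at most 3 dimensions:
//
//   LoopedElemToLoc<3, true, int64_t> looper(ndim);
//   for (int64_t e = 0; e < total_elements; ++e) {
//     out[e] = in[looper.location()];
//     looper.next(shape, strides);
//   }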