mlx/mlx/backend/cuda/kernel_utils.cuh

// Copyright © 2025 Apple Inc.

// This file provides host-only utilities for writing CUDA kernels; the
// difference from backend/cuda/kernels/utils.cuh is that the latter file
// includes only device-only code.
#pragma once

#include "mlx/array.h"
#include "mlx/backend/cuda/kernels/utils.cuh"

#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <fmt/format.h>
namespace mlx::core {

// Maps CPU types to CUDA types.
template <typename T>
struct CTypeToCudaType {
  using type = T;
};

template <>
struct CTypeToCudaType<float16_t> {
  using type = __half;
};

template <>
struct CTypeToCudaType<bfloat16_t> {
  using type = __nv_bfloat16;
};

template <>
struct CTypeToCudaType<complex64_t> {
  using type = cuComplex;
};

template <typename T>
using cuda_type_t = typename CTypeToCudaType<T>::type;
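
// Example (illustrative sketch, not part of the original header): host code can
// use cuda_type_t<T> to reinterpret typed array data as the matching CUDA type.
// The helper name and the use of array::data<T>() are assumptions.
//
//   template <typename T>
//   cuda_type_t<T>* as_cuda_ptr(array& arr) {
//     return reinterpret_cast<cuda_type_t<T>*>(arr.data<T>());
//   }
//
//   // cuda_type_t<float>        -> float
//   // cuda_type_t<float16_t>    -> __half
//   // cuda_type_t<bfloat16_t>   -> __nv_bfloat16
//   // cuda_type_t<complex64_t>  -> cuComplex
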
// Type trait for detecting floating-point types.
template <typename T>
inline constexpr bool is_floating_v =
    cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
    cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
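
// Example (sketch): the trait can gate kernel code paths at compile time; the
// checks below are illustrative only.
//
//   static_assert(is_floating_v<float16_t>, "half is a floating type");
//   static_assert(!is_floating_v<int32_t>, "int32 is not a floating type");
//
//   template <typename T>
//   void check_float_op() {
//     static_assert(is_floating_v<T>, "op only supports floating point types");
//   }
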
// Utility to copy data from an std::vector into a fixed-size cuda::std::array
// on the host.
template <int NDIM = MAX_NDIM, typename T = int32_t>
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
  if (vec.size() > NDIM) {
    throw std::runtime_error(
        fmt::format("ndim cannot be larger than {}.", NDIM));
  }
  cuda::std::array<T, NDIM> result;
  std::copy_n(vec.begin(), vec.size(), result.begin());
  return result;
}
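
// Example (sketch): host-side shapes and strides held in std::vector are copied
// into fixed-size cuda::std::array values that can be passed to a kernel by
// value. The function and kernel names are hypothetical, and the element types
// shown assume mlx's Shape (int32_t) and Strides (int64_t) definitions.
//
//   void launch_copy(const Shape& shape, const Strides& strides) {
//     auto shape_arr = const_param(shape);      // cuda::std::array<int32_t, MAX_NDIM>
//     auto strides_arr = const_param(strides);  // cuda::std::array<int64_t, MAX_NDIM>
//     // copy_kernel<<<grid, block>>>(in, out, shape_arr, strides_arr, shape.size());
//   }
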
// Compute the grid and block dimensions; see backend/common/utils.h for docs.
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
dim3 get_2d_grid_dims(
    const Shape& shape,
    const Strides& strides,
    size_t divisor);
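
// Example (sketch): a host-side launch might size its grid from the output
// array. The divisor argument is assumed here to be the per-block thread
// count; see backend/common/utils.h for the exact semantics. The kernel name
// and block size are illustrative only.
//
//   constexpr int block_dim = 256;
//   dim3 grid = get_2d_grid_dims(out.shape(), out.strides(), block_dim);
//   // unary_kernel<<<grid, block_dim>>>(in_ptr, out_ptr, out.size());
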
} // namespace mlx::core