// Copyright © 2025 Apple Inc.
// This file includes host-only utilities for writing CUDA kernels, the
// difference from backend/cuda/device/utils.cuh is that the latter file only
// include device-only code.
#pragma once
#include <type_traits>
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <fmt/format.h>
#include <cuda/cmath>
namespace mlx::core {
// Invoke |f| with a std::integral_constant<int, N> matching the runtime
// value |n|, turning a runtime dimension count into a compile-time one.
// |n| must be 1, 2 or 3; any other value is a silent no-op (|f| is never
// called) — callers are expected to validate n beforehand.
template <typename F>
void dispatch_1_2_3(int n, F&& f) {
  switch (n) {
    case 1:
      f(std::integral_constant<int, 1>{});
      break;
    case 2:
      f(std::integral_constant<int, 2>{});
      break;
    case 3:
      f(std::integral_constant<int, 3>{});
      break;
  }
}
// Invoke |f| with std::true_type or std::false_type depending on the
// runtime value |v|, turning a runtime bool into a compile-time constant
// that the callee can branch on with if constexpr.
template <typename F>
void dispatch_bool(bool v, F&& f) {
  if (v) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}
// Invoke |f| with the smallest block dim (as a compile-time
// std::integral_constant<int, N>, a power-of-two multiple of WARP_SIZE)
// that can hold |threads| threads, capped at WARP_SIZE * 32.
template <typename F>
void dispatch_block_dim(int threads, F&& f) {
  if (threads <= WARP_SIZE) {
    f(std::integral_constant<int, WARP_SIZE>{});
  } else if (threads <= WARP_SIZE * 2) {
    f(std::integral_constant<int, WARP_SIZE * 2>{});
  } else if (threads <= WARP_SIZE * 4) {
    f(std::integral_constant<int, WARP_SIZE * 4>{});
  } else if (threads <= WARP_SIZE * 8) {
    f(std::integral_constant<int, WARP_SIZE * 8>{});
  } else if (threads <= WARP_SIZE * 16) {
    f(std::integral_constant<int, WARP_SIZE * 16>{});
  } else {
    // Anything larger than 16 warps gets the maximum block dim.
    f(std::integral_constant<int, WARP_SIZE * 32>{});
  }
}
2025-06-12 03:22:25 +09:00
2025-05-07 13:26:46 +09:00
// Maps CPU types to CUDA types: by default a type maps to itself, while
// the MLX half-precision and complex wrappers map to the corresponding
// native CUDA types used inside kernels.
template <typename T>
struct CTypeToCudaType {
  using type = T;
};

template <>
struct CTypeToCudaType<float16_t> {
  using type = __half;
};

template <>
struct CTypeToCudaType<bfloat16_t> {
  using type = __nv_bfloat16;
};

template <>
struct CTypeToCudaType<complex64_t> {
  using type = cu::complex64_t;
};

// Convenience alias: the CUDA-side type corresponding to CPU type T.
template <typename T>
using cuda_type_t = typename CTypeToCudaType<T>::type;
// Type trait: true for the real floating-point types supported here
// (float, double and the MLX half-precision wrappers).
template <typename T>
inline constexpr bool is_floating_v =
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
// Type trait: true for the MLX complex number types.
template <typename T>
inline constexpr bool is_complex_v = cuda::std::is_same_v<T, complex64_t> ||
cuda::std::is_same_v<T, complex128_t>;
// Type trait: true for any inexact type — real floating point or complex.
template <typename T>
inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;
// Utility to copy data from vector to array in host, so it can be passed
// to a kernel as a fixed-size constant parameter.
//
// Throws std::runtime_error if |vec| has more than NDIM entries. Entries
// past vec.size() are value-initialized so no uninitialized bytes are
// handed to the kernel launch.
template <int NDIM = MAX_NDIM, typename T = int32_t>
inline cuda::std::array<T, NDIM> const_param(const SmallVector<T>& vec) {
  if (vec.size() > NDIM) {
    throw std::runtime_error(
        fmt::format("ndim can not be larger than {}.", NDIM));
  }
  // Value-initialize so the tail beyond vec.size() is zeroed rather than
  // left as indeterminate stack garbage.
  cuda::std::array<T, NDIM> result = {};
  std::copy_n(vec.begin(), vec.size(), result.begin());
  return result;
}
// Compute the grid and block dimensions, check backend/common/utils.h for docs.
// Block dims covering (dim0, dim1, dim2), with at most 2^pow2 threads total.
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
// 2D grid dims covering the element count implied by shape/strides.
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
// As above, but the element count is divided by |divisor| first.
dim3 get_2d_grid_dims(
    const Shape& shape,
    const Strides& strides,
    size_t divisor);
// Grid and block dims covering (dim0, dim1, dim2) as a pair.
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
// Get the num_blocks and block_dims assuming each thread handles
// |work_per_thread| elements of the array described by
// (size, shape, strides). |large| selects 64-bit indexing; block size
// is capped at |max_block_dim| threads.
std::tuple<dim3, uint> get_launch_args(
    size_t size,
    const Shape& shape,
    const Strides& strides,
    bool large,
    int work_per_thread = 1,
    uint max_block_dim = 1024);
// Convenience overload: reads size/shape/strides from |arr| and forwards
// to the declaration above.
inline std::tuple<dim3, uint> get_launch_args(
    const array& arr,
    bool large,
    int work_per_thread = 1,
    uint max_block_dim = 1024) {
  return get_launch_args(
      arr.size(),
      arr.shape(),
      arr.strides(),
      large,
      work_per_thread,
      max_block_dim);
}
} // namespace mlx::core