CUDA backend: unary ops (#2158)

This commit is contained in:
Cheng
2025-06-09 22:45:08 +09:00
committed by GitHub
parent 5866b3857b
commit f8bad60609
13 changed files with 1074 additions and 70 deletions

View File

@@ -7,10 +7,12 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/kernels/utils.cuh"
#include <cuComplex.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <fmt/format.h>
namespace mlx::core {
@@ -38,6 +40,24 @@ struct CTypeToCudaType<complex64_t> {
template <typename T>
using cuda_type_t = typename CTypeToCudaType<T>::type;
// Type traits for detecting floating numbers.
template <typename T>
inline constexpr bool is_floating_v =
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
// Utility to copy data from vector to array in host.
template <int NDIM = MAX_NDIM, typename T = int32_t>
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
if (vec.size() > NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM));
}
cuda::std::array<T, NDIM> result;
std::copy_n(vec.begin(), vec.size(), result.begin());
return result;
}
// Compute the grid and block dimensions, check backend/common/utils.h for docs.
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);