mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
CUDA backend: unary ops (#2158)
This commit is contained in:
@@ -7,10 +7,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/kernels/utils.cuh"
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -38,6 +40,24 @@ struct CTypeToCudaType<complex64_t> {
|
||||
template <typename T>
|
||||
using cuda_type_t = typename CTypeToCudaType<T>::type;
|
||||
|
||||
// Type traits for detecting floating numbers.
|
||||
template <typename T>
|
||||
inline constexpr bool is_floating_v =
|
||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
||||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
||||
|
||||
// Utility to copy data from vector to array in host.
|
||||
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
||||
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
||||
if (vec.size() > NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||
}
|
||||
cuda::std::array<T, NDIM> result;
|
||||
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||
return result;
|
||||
}
|
||||
|
||||
// Compute the grid and block dimensions, check backend/common/utils.h for docs.
|
||||
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
|
||||
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
|
||||
|
||||
Reference in New Issue
Block a user