CUDA backend: compile (#2276)

* CUDA backend: compile

* Rename kernels/ to device/
This commit is contained in:
Cheng
2025-06-13 09:08:39 +09:00
committed by GitHub
parent f5f65ef48c
commit a4fc671d3e
27 changed files with 910 additions and 27 deletions

View File

@@ -2,6 +2,7 @@
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/dtype_utils.h"
#include <fmt/format.h>
@@ -23,4 +24,20 @@ void check_cuda_error(const char* name, cudaError_t err) {
}
}
const char* dtype_to_cuda_type(const Dtype& dtype) {
if (dtype == float16) {
return "__half";
}
if (dtype == bfloat16) {
return "__nv_bfloat16";
}
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
if (dtype == DTYPE) { \
return #CPP_TYPE; \
}
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
#undef SPECIALIZE_DtypeToString
return nullptr;
}
} // namespace mlx::core