mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
* cuda graph prototype fix signal bug + start to add dependencies capture more capture more ops remaining ops fix reduce and rope deps add concurrent context try update, but not working cosistent topology order use node api use node api directly to reduce overhead fix bug use kernels in unary cache graph format fix synchronization format * comment
71 lines
1.5 KiB
C++
71 lines
1.5 KiB
C++
// Copyright © 2025 Apple Inc.
|
|
|
|
#include "mlx/backend/cuda/utils.h"
|
|
#include "mlx/backend/cuda/device.h"
|
|
#include "mlx/dtype_utils.h"
|
|
|
|
#include <fmt/format.h>
|
|
|
|
namespace mlx::core {
|
|
|
|
CudaStream::CudaStream(cu::Device& device) {
|
|
device.make_current();
|
|
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
|
|
}
|
|
|
|
CudaStream::~CudaStream() {
|
|
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
|
|
}
|
|
|
|
void check_cuda_error(const char* name, cudaError_t err) {
|
|
if (err != cudaSuccess) {
|
|
throw std::runtime_error(
|
|
fmt::format("{} failed: {}", name, cudaGetErrorString(err)));
|
|
}
|
|
}
|
|
|
|
void check_cuda_error(const char* name, CUresult err) {
|
|
if (err != CUDA_SUCCESS) {
|
|
const char* err_str = "Unknown error";
|
|
cuGetErrorString(err, &err_str);
|
|
throw std::runtime_error(fmt::format("{} failed: {}", name, err_str));
|
|
}
|
|
}
|
|
|
|
const char* dtype_to_cuda_type(const Dtype& dtype) {
|
|
switch (dtype) {
|
|
case bool_:
|
|
return "bool";
|
|
case int8:
|
|
return "int8_t";
|
|
case int16:
|
|
return "int16_t";
|
|
case int32:
|
|
return "int32_t";
|
|
case int64:
|
|
return "int64_t";
|
|
case uint8:
|
|
return "uint8_t";
|
|
case uint16:
|
|
return "uint16_t";
|
|
case uint32:
|
|
return "uint32_t";
|
|
case uint64:
|
|
return "uint64_t";
|
|
case float16:
|
|
return "__half";
|
|
case bfloat16:
|
|
return "__nv_bfloat16";
|
|
case float32:
|
|
return "float";
|
|
case float64:
|
|
return "double";
|
|
case complex64:
|
|
return "cuComplex";
|
|
default:
|
|
return "unknown";
|
|
}
|
|
}
|
|
|
|
} // namespace mlx::core
|