diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp
index 29d4803a8..7a6883578 100644
--- a/mlx/backend/cuda/conv.cpp
+++ b/mlx/backend/cuda/conv.cpp
@@ -190,10 +190,7 @@ bool execute_plan(
   cudnnSetStream(handle, encoder.stream());
 
 #if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
-  cudaGraph_t graph;
-  cudaGraphCreate(&graph, 0);
-  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
-      &graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
+  CudaGraph graph(encoder.device());
   if (cudnnBackendPopulateCudaGraph(
           handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
       CUDNN_STATUS_SUCCESS) {
diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 96b07502f..371ae020c 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -91,9 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
 }
 
 CommandEncoder::CaptureContext::~CaptureContext() {
-  CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
-  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
-      &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
+  graph.end_capture(enc.stream());
   if (discard) {
     return;
   }
@@ -185,9 +183,10 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
 }
 
 CommandEncoder::CommandEncoder(Device& d)
-    : device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) {
-  CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
-}
+    : device_(d),
+      stream_(d),
+      graph_(d),
+      graph_cache_(cuda_graph_cache_size()) {}
 
 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
@@ -311,8 +310,7 @@ void CommandEncoder::commit() {
     to_nodes_.clear();
     graph_key_.clear();
     node_map_.clear();
-    CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
-    CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
+    graph_ = CudaGraph(device_);
   }
 
   // Put completion handlers in a batch.
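
Reviewer note: the one-liner `graph_ = CudaGraph(device_);` in `commit()` relies on the move assignment that `CudaHandle` gains later in this patch (utils.h). It first destroys the graph currently owned by `graph_` and then adopts the handle of the freshly created temporary, which is why the explicit destroy/create pair can go away. The stand-alone sketch below illustrates that ownership hand-off; it uses a fake handle type and a made-up `HandleOwner` class purely for illustration, not the MLX wrappers themselves.

#include <cassert>
#include <cstdio>
#include <utility>

// Stand-in for a C-style CUDA resource API (think cudaGraphCreate/Destroy).
using FakeGraph = int*;
inline int fake_destroy(FakeGraph g) {
  std::puts("destroy");
  delete g;
  return 0;
}

// Simplified form of the CudaHandle pattern introduced in utils.h.
template <typename Handle, int (*Destroy)(Handle)>
class HandleOwner {
 public:
  HandleOwner(Handle h = nullptr) : handle_(h) {}
  HandleOwner(HandleOwner&& other) : handle_(other.handle_) {
    other.handle_ = nullptr;
  }
  HandleOwner& operator=(HandleOwner&& other) {
    assert(this != &other);
    reset();                            // free the resource we currently hold
    std::swap(handle_, other.handle_);  // then adopt the other handle
    return *this;
  }
  HandleOwner(const HandleOwner&) = delete;
  HandleOwner& operator=(const HandleOwner&) = delete;
  ~HandleOwner() { reset(); }
  void reset() {
    if (handle_ != nullptr) {
      Destroy(handle_);
      handle_ = nullptr;
    }
  }
 private:
  Handle handle_;
};

int main() {
  HandleOwner<FakeGraph, fake_destroy> graph(new int(1));
  // Equivalent of `graph_ = CudaGraph(device_);` in commit():
  graph = HandleOwner<FakeGraph, fake_destroy>(new int(2));  // prints "destroy"
}  // prints "destroy" again for the second handle
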
diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h
index 5eb7fd4c1..7b0ff5629 100644
--- a/mlx/backend/cuda/device.h
+++ b/mlx/backend/cuda/device.h
@@ -21,7 +21,7 @@ class CommandEncoder {
   struct CaptureContext {
     CaptureContext(CommandEncoder& enc);
     ~CaptureContext();
-    cudaGraph_t graph;
+    CudaGraph graph;
     CommandEncoder& enc;
     bool discard{false};
   };
@@ -115,7 +115,7 @@ class CommandEncoder {
 
   Device& device_;
   CudaStream stream_;
-  cudaGraph_t graph_;
+  CudaGraph graph_;
   Worker worker_;
   char node_count_{0};
   char graph_node_count_{0};
diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp
index 88940a234..09894d4ca 100644
--- a/mlx/backend/cuda/utils.cpp
+++ b/mlx/backend/cuda/utils.cpp
@@ -8,36 +8,6 @@
 
 namespace mlx::core {
 
-CudaStream::CudaStream(cu::Device& device) {
-  device.make_current();
-  CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
-}
-
-CudaStream::~CudaStream() {
-  CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
-}
-
-CudaGraphExec::CudaGraphExec(cudaGraphExec_t handle) : handle_(handle) {}
-
-CudaGraphExec::CudaGraphExec(CudaGraphExec&& other) : handle_(other.handle_) {
-  other.handle_ = nullptr;
-};
-
-CudaGraphExec::~CudaGraphExec() {
-  reset();
-}
-
-void CudaGraphExec::instantiate(cudaGraph_t graph) {
-  CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
-}
-
-void CudaGraphExec::reset() {
-  if (handle_ != nullptr) {
-    CHECK_CUDA_ERROR(cudaGraphExecDestroy(handle_));
-    handle_ = nullptr;
-  }
-}
-
 void check_cublas_error(const char* name, cublasStatus_t err) {
   if (err != CUBLAS_STATUS_SUCCESS) {
     // TODO: Use cublasGetStatusString when it is widely available.
@@ -96,4 +66,24 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
   }
 }
 
+CudaGraph::CudaGraph(cu::Device& device) {
+  device.make_current();
+  CHECK_CUDA_ERROR(cudaGraphCreate(&handle_, 0));
+}
+
+void CudaGraph::end_capture(cudaStream_t stream) {
+  assert(handle_ == nullptr);
+  CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
+}
+
+void CudaGraphExec::instantiate(cudaGraph_t graph) {
+  assert(handle_ == nullptr);
+  CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
+}
+
+CudaStream::CudaStream(cu::Device& device) {
+  device.make_current();
+  CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&handle_, cudaStreamNonBlocking));
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h
index 555e15065..e811d5e6c 100644
--- a/mlx/backend/cuda/utils.h
+++ b/mlx/backend/cuda/utils.h
@@ -16,44 +16,6 @@ class Device;
 
 struct Dtype;
 
-// Cuda stream managed with RAII.
-class CudaStream {
- public:
-  explicit CudaStream(cu::Device& device);
-  ~CudaStream();
-
-  CudaStream(const CudaStream&) = delete;
-  CudaStream& operator=(const CudaStream&) = delete;
-
-  operator cudaStream_t() const {
-    return stream_;
-  }
-
- private:
-  cudaStream_t stream_;
-};
-
-// Move-able RAII handle of cudaGraphExec_t.
-class CudaGraphExec {
- public:
-  CudaGraphExec(cudaGraphExec_t handle = nullptr);
-  CudaGraphExec(CudaGraphExec&& other);
-  ~CudaGraphExec();
-
-  CudaGraphExec(const CudaGraphExec&) = delete;
-  CudaGraphExec& operator=(const CudaGraphExec&) = delete;
-
-  void instantiate(cudaGraph_t graph);
-  void reset();
-
-  operator cudaGraphExec_t() const {
-    return handle_;
-  }
-
- private:
-  cudaGraphExec_t handle_;
-};
-
 // Throw exception if the cuda API does not succeed.
 void check_cublas_error(const char* name, cublasStatus_t err);
 void check_cuda_error(const char* name, cudaError_t err);
@@ -66,4 +28,62 @@ void check_cuda_error(const char* name, CUresult err);
 // Convert Dtype to CUDA C++ types.
 const char* dtype_to_cuda_type(const Dtype& dtype);
 
+// Base class for RAII managed CUDA resources.
+template <typename Handle, cudaError_t (*Destroy)(Handle)>
+class CudaHandle {
+ public:
+  CudaHandle(Handle handle = nullptr) : handle_(handle) {}
+
+  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
+    assert(this != &other);
+    other.handle_ = nullptr;
+  }
+
+  ~CudaHandle() {
+    reset();
+  }
+
+  CudaHandle(const CudaHandle&) = delete;
+  CudaHandle& operator=(const CudaHandle&) = delete;
+
+  CudaHandle& operator=(CudaHandle&& other) {
+    assert(this != &other);
+    reset();
+    std::swap(handle_, other.handle_);
+    return *this;
+  }
+
+  void reset() {
+    if (handle_ != nullptr) {
+      CHECK_CUDA_ERROR(Destroy(handle_));
+      handle_ = nullptr;
+    }
+  }
+
+  operator Handle() const {
+    return handle_;
+  }
+
+ protected:
+  Handle handle_;
+};
+
+// Wrappers of CUDA resources.
+class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
+ public:
+  using CudaHandle::CudaHandle;
+  explicit CudaGraph(cu::Device& device);
+  void end_capture(cudaStream_t stream);
+};
+
+class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
+ public:
+  void instantiate(cudaGraph_t graph);
+};
+
+class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
+ public:
+  explicit CudaStream(cu::Device& device);
+};
+
 } // namespace mlx::core
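
Usage-wise, the intended flow with the new wrappers looks roughly like the sketch below. This is hypothetical caller code, not part of the patch: the function name `capture_and_replay`, the `out`/`n` arguments, and the memset workload are made up for illustration; `CHECK_CUDA_ERROR` comes from utils.h, and the CUDA runtime calls (cudaStreamBeginCapture, cudaGraphLaunch, cudaStreamSynchronize) are the real APIs the wrappers sit on.

#include <cuda_runtime.h>

#include "mlx/backend/cuda/utils.h"

using namespace mlx::core;

void capture_and_replay(cu::Device& device, float* out, size_t n) {
  CudaStream stream(device);  // cudaStreamCreateWithFlags; destroyed by ~CudaHandle

  // Record work on the stream instead of running it eagerly.
  CHECK_CUDA_ERROR(
      cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal));
  CHECK_CUDA_ERROR(cudaMemsetAsync(out, 0, n * sizeof(float), stream));

  CudaGraph graph;            // starts as a null handle
  graph.end_capture(stream);  // cudaStreamEndCapture fills handle_

  CudaGraphExec exec;
  exec.instantiate(graph);    // cudaGraphInstantiate
  CHECK_CUDA_ERROR(cudaGraphLaunch(exec, stream));
  CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
}  // exec, graph, and stream release themselves in reverse declaration order

Passing the destroy function as a non-type template parameter keeps each wrapper the size of a raw handle and makes cleanup a direct call, much like std::unique_ptr with a stateless deleter.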