From aaf78f4c6b618daa0fd74697a47cc467920e958f Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 2 Aug 2025 21:28:57 +0900 Subject: [PATCH] Use LRU cache for cuda graph (#2448) * Use LRU cache for cuda graph * Remove unused destructor --- mlx/backend/cuda/device.cpp | 27 +++++---------------------- mlx/backend/cuda/device.h | 4 ++-- mlx/backend/cuda/lru_cache.h | 15 ++++++++++++++- mlx/backend/cuda/utils.cpp | 21 +++++++++++++++++++++ mlx/backend/cuda/utils.h | 21 +++++++++++++++++++++ 5 files changed, 63 insertions(+), 25 deletions(-) diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 7a454e7d7f..96b07502f7 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -184,21 +184,11 @@ void CommandEncoder::insert_graph_dependencies(std::vector nodes) { } } -CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) { +CommandEncoder::CommandEncoder(Device& d) + : device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) { CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0)); } -void clear_graphs(std::unordered_map& graphs) { - for (auto& [_, graph_exec] : graphs) { - CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec)); - } - graphs.clear(); -} - -CommandEncoder::~CommandEncoder() { - clear_graphs(graph_cache_); -} - void CommandEncoder::add_completed_handler(std::function task) { worker_.add_task(std::move(task)); } @@ -290,7 +280,7 @@ void CommandEncoder::commit() { graph_key_ += "."; graph_key_ += std::to_string(empty_node_count_); - cudaGraphExec_t& graph_exec = graph_cache_[graph_key_]; + CudaGraphExec& graph_exec = graph_cache_[graph_key_]; if (graph_exec != nullptr) { cudaGraphExecUpdateResult update_result; @@ -304,22 +294,15 @@ void CommandEncoder::commit() { #endif // CUDART_VERSION >= 12000 if (update_result != cudaGraphExecUpdateSuccess) { cudaGetLastError(); // reset error - CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec)); - graph_exec = nullptr; + graph_exec.reset(); } } if (graph_exec == nullptr) { - CHECK_CUDA_ERROR( - cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0)); + graph_exec.instantiate(graph_); } device_.make_current(); CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); - // TODO smarter cache policy - if (graph_cache_.size() > cuda_graph_cache_size()) { - clear_graphs(graph_cache_); - } - // Reset state node_count_ = 0; graph_node_count_ = 0; diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index ea932082c5..5eb7fd4c15 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -3,6 +3,7 @@ #pragma once #include "mlx/array.h" +#include "mlx/backend/cuda/lru_cache.h" #include "mlx/backend/cuda/worker.h" #include "mlx/stream.h" @@ -31,7 +32,6 @@ class CommandEncoder { }; explicit CommandEncoder(Device& d); - ~CommandEncoder(); CommandEncoder(const CommandEncoder&) = delete; CommandEncoder& operator=(const CommandEncoder&) = delete; @@ -126,7 +126,7 @@ class CommandEncoder { std::string graph_key_; std::vector concurrent_nodes_; std::vector> temporaries_; - std::unordered_map graph_cache_; + LRUCache graph_cache_; std::vector active_deps_; std::vector active_outputs_; std::unordered_map node_map_; diff --git a/mlx/backend/cuda/lru_cache.h b/mlx/backend/cuda/lru_cache.h index c8df2fa93c..79fca0ae71 100644 --- a/mlx/backend/cuda/lru_cache.h +++ b/mlx/backend/cuda/lru_cache.h @@ -2,6 +2,7 @@ #pragma once +#include #include #include #include @@ -20,7 +21,11 @@ class LRUCache { using const_iterator = typename list_type::const_iterator; using map_type = M; - explicit LRUCache(size_t capacity) : capacity_(capacity) {} + explicit LRUCache(size_t capacity) : capacity_(capacity) { + if (capacity == 0) { + throw std::runtime_error("LRUCache requires capacity > 0."); + } + } size_t size() const { return map_.size(); @@ -84,6 +89,14 @@ class LRUCache { return vlist_.erase(pos); } + V& operator[](const K& key) { + auto it = find(key); + if (it == end()) { + it = emplace(key, V{}).first; + } + return it->second; + } + private: void trim() { while (map_.size() > capacity_) { diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index baab3b2a53..88940a2341 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -17,6 +17,27 @@ CudaStream::~CudaStream() { CHECK_CUDA_ERROR(cudaStreamDestroy(stream_)); } +CudaGraphExec::CudaGraphExec(cudaGraphExec_t handle) : handle_(handle) {} + +CudaGraphExec::CudaGraphExec(CudaGraphExec&& other) : handle_(other.handle_) { + other.handle_ = nullptr; +}; + +CudaGraphExec::~CudaGraphExec() { + reset(); +} + +void CudaGraphExec::instantiate(cudaGraph_t graph) { + CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0)); +} + +void CudaGraphExec::reset() { + if (handle_ != nullptr) { + CHECK_CUDA_ERROR(cudaGraphExecDestroy(handle_)); + handle_ = nullptr; + } +} + void check_cublas_error(const char* name, cublasStatus_t err) { if (err != CUBLAS_STATUS_SUCCESS) { // TODO: Use cublasGetStatusString when it is widely available. diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index f2b6b16cb8..555e150656 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -33,6 +33,27 @@ class CudaStream { cudaStream_t stream_; }; +// Move-able RAII handle of cudaGraphExec_t. +class CudaGraphExec { + public: + CudaGraphExec(cudaGraphExec_t handle = nullptr); + CudaGraphExec(CudaGraphExec&& other); + ~CudaGraphExec(); + + CudaGraphExec(const CudaGraphExec&) = delete; + CudaGraphExec& operator=(const CudaGraphExec&) = delete; + + void instantiate(cudaGraph_t graph); + void reset(); + + operator cudaGraphExec_t() const { + return handle_; + } + + private: + cudaGraphExec_t handle_; +}; + // Throw exception if the cuda API does not succeed. void check_cublas_error(const char* name, cublasStatus_t err); void check_cuda_error(const char* name, cudaError_t err);