Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-11 19:56:40 +08:00)
Commit aaf78f4c6b (parent: 8831064493)

Use LRU cache for cuda graph (#2448)

* Use LRU cache for cuda graph
* Remove unused destructor
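In short: CommandEncoder's graph cache changes from a std::unordered_map<std::string, cudaGraphExec_t>, which was cleared wholesale whenever it grew past cuda_graph_cache_size(), to a capacity-bounded LRUCache<std::string, CudaGraphExec>. Stale executable graphs are now evicted one at a time in least-recently-used order, and the new RAII wrapper CudaGraphExec destroys each evicted handle automatically, which is what makes the explicit ~CommandEncoder destructor (and its clear_graphs helper) unnecessary.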
mlx/backend/cuda/device.cpp

@@ -184,21 +184,11 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
   }
 }
 
-CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
+CommandEncoder::CommandEncoder(Device& d)
+    : device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) {
   CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
 }
 
-void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {
-  for (auto& [_, graph_exec] : graphs) {
-    CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
-  }
-  graphs.clear();
-}
-
-CommandEncoder::~CommandEncoder() {
-  clear_graphs(graph_cache_);
-}
-
 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
 }

@@ -290,7 +280,7 @@ void CommandEncoder::commit() {
   graph_key_ += ".";
   graph_key_ += std::to_string(empty_node_count_);
 
-  cudaGraphExec_t& graph_exec = graph_cache_[graph_key_];
+  CudaGraphExec& graph_exec = graph_cache_[graph_key_];
 
   if (graph_exec != nullptr) {
     cudaGraphExecUpdateResult update_result;

@@ -304,22 +294,15 @@
 #endif // CUDART_VERSION >= 12000
     if (update_result != cudaGraphExecUpdateSuccess) {
       cudaGetLastError(); // reset error
-      CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
-      graph_exec = nullptr;
+      graph_exec.reset();
     }
   }
   if (graph_exec == nullptr) {
-    CHECK_CUDA_ERROR(
-        cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
+    graph_exec.instantiate(graph_);
   }
   device_.make_current();
   CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
 
-  // TODO smarter cache policy
-  if (graph_cache_.size() > cuda_graph_cache_size()) {
-    clear_graphs(graph_cache_);
-  }
-
   // Reset state
   node_count_ = 0;
   graph_node_count_ = 0;
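For context, the update-or-instantiate flow that commit() keeps here is the standard CUDA graph reuse pattern: try a cheap in-place cudaGraphExecUpdate first, and fall back to a fresh instantiation only if the update is rejected. A minimal standalone sketch, not MLX code; launch_or_update and its parameters are illustrative and error checks are omitted for brevity:

    #include <cuda_runtime.h>

    // Launch `graph` on `stream`, reusing `exec` across calls when possible.
    void launch_or_update(
        cudaGraph_t graph, cudaGraphExec_t& exec, cudaStream_t stream) {
      if (exec != nullptr) {
    #if CUDART_VERSION >= 12000
        // In-place update is much cheaper than re-instantiation when the
        // graph topology is unchanged (e.g. only kernel parameters differ).
        cudaGraphExecUpdateResultInfo info;
        if (cudaGraphExecUpdate(exec, graph, &info) != cudaSuccess) {
          cudaGetLastError();          // clear the sticky error state
          cudaGraphExecDestroy(exec);  // fall back to a fresh instantiation
          exec = nullptr;
        }
    #else
        cudaGraphExecUpdateResult result;
        cudaGraphNode_t error_node;
        cudaGraphExecUpdate(exec, graph, &error_node, &result);
        if (result != cudaGraphExecUpdateSuccess) {
          cudaGetLastError();
          cudaGraphExecDestroy(exec);
          exec = nullptr;
        }
    #endif
      }
      if (exec == nullptr) {
        cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);
      }
      cudaGraphLaunch(exec, stream);
    }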
mlx/backend/cuda/device.h

@@ -3,6 +3,7 @@
 #pragma once
 
 #include "mlx/array.h"
+#include "mlx/backend/cuda/lru_cache.h"
 #include "mlx/backend/cuda/worker.h"
 #include "mlx/stream.h"
 

@@ -31,7 +32,6 @@ class CommandEncoder {
   };
 
   explicit CommandEncoder(Device& d);
-  ~CommandEncoder();
 
   CommandEncoder(const CommandEncoder&) = delete;
   CommandEncoder& operator=(const CommandEncoder&) = delete;

@@ -126,7 +126,7 @@ class CommandEncoder {
   std::string graph_key_;
   std::vector<GraphNode> concurrent_nodes_;
   std::vector<std::shared_ptr<array::Data>> temporaries_;
-  std::unordered_map<std::string, cudaGraphExec_t> graph_cache_;
+  LRUCache<std::string, CudaGraphExec> graph_cache_;
   std::vector<std::uintptr_t> active_deps_;
   std::vector<std::uintptr_t> active_outputs_;
   std::unordered_map<std::uintptr_t, GraphNode> node_map_;
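Note on the member swap: unlike std::unordered_map, LRUCache has no default constructor and its capacity must be positive, which is why the device.cpp hunk above now passes cuda_graph_cache_size() in CommandEncoder's member-initializer list. A sketch of the pattern (Encoder is a stand-in, not the real class):

    #include <cstddef>
    #include <string>

    #include "mlx/backend/cuda/lru_cache.h"
    #include "mlx/backend/cuda/utils.h"  // CudaGraphExec

    struct Encoder {
      // The cache capacity has to be supplied at construction time.
      explicit Encoder(size_t cache_size) : cache_(cache_size) {}
      LRUCache<std::string, CudaGraphExec> cache_;
    };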
mlx/backend/cuda/lru_cache.h

@@ -2,6 +2,7 @@
 
 #pragma once
 
+#include <cstring>
 #include <list>
 #include <unordered_map>
 #include <utility>

@@ -20,7 +21,11 @@ class LRUCache {
   using const_iterator = typename list_type::const_iterator;
   using map_type = M<K, iterator>;
 
-  explicit LRUCache(size_t capacity) : capacity_(capacity) {}
+  explicit LRUCache(size_t capacity) : capacity_(capacity) {
+    if (capacity == 0) {
+      throw std::runtime_error("LRUCache requires capacity > 0.");
+    }
+  }
 
   size_t size() const {
     return map_.size();

@@ -84,6 +89,14 @@
     return vlist_.erase(pos);
   }
 
+  V& operator[](const K& key) {
+    auto it = find(key);
+    if (it == end()) {
+      it = emplace(key, V{}).first;
+    }
+    return it->second;
+  }
+
  private:
   void trim() {
     while (map_.size() > capacity_) {
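The new operator[] is what gives LRUCache the drop-in, map-like interface that graph_cache_[graph_key_] relies on in commit(): a miss default-constructs the value, exactly as unordered_map did. A usage sketch with illustrative values, assuming the usual LRU behavior of this class (a hit refreshes an entry's recency, and an insert beyond capacity evicts the least recently used entry):

    LRUCache<std::string, int> cache(2);  // capacity 2; a capacity of 0 throws
    cache["a"] = 1;                       // miss: default-constructs, then assigns
    cache["b"] = 2;
    int x = cache["a"];                   // hit: "a" becomes most recently used
    cache["c"] = 3;                       // over capacity: evicts "b", the LRU entry
    // cache now holds "a" and "c"; cache.find("b") == cache.end()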
mlx/backend/cuda/utils.cpp

@@ -17,6 +17,27 @@ CudaStream::~CudaStream() {
   CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
 }
 
+CudaGraphExec::CudaGraphExec(cudaGraphExec_t handle) : handle_(handle) {}
+
+CudaGraphExec::CudaGraphExec(CudaGraphExec&& other) : handle_(other.handle_) {
+  other.handle_ = nullptr;
+};
+
+CudaGraphExec::~CudaGraphExec() {
+  reset();
+}
+
+void CudaGraphExec::instantiate(cudaGraph_t graph) {
+  CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
+}
+
+void CudaGraphExec::reset() {
+  if (handle_ != nullptr) {
+    CHECK_CUDA_ERROR(cudaGraphExecDestroy(handle_));
+    handle_ = nullptr;
+  }
+}
+
 void check_cublas_error(const char* name, cublasStatus_t err) {
   if (err != CUBLAS_STATUS_SUCCESS) {
     // TODO: Use cublasGetStatusString when it is widely available.
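With this RAII wrapper, destroying or evicting a cache entry releases the underlying cudaGraphExec_t automatically, so no explicit cleanup pass is needed anywhere. A small sketch of the ownership semantics (illustrative; `example` is hypothetical and `graph` is assumed to be a valid cudaGraph_t):

    #include <utility>

    void example(cudaGraph_t graph) {
      CudaGraphExec exec;                    // default: holds nullptr
      exec.instantiate(graph);               // now owns a cudaGraphExec_t
      CudaGraphExec other{std::move(exec)};  // ownership transfers; exec is empty
    }  // ~CudaGraphExec runs reset() on `other`, destroying the handle exactly once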
mlx/backend/cuda/utils.h

@@ -33,6 +33,27 @@ class CudaStream {
   cudaStream_t stream_;
 };
 
+// Move-able RAII handle of cudaGraphExec_t.
+class CudaGraphExec {
+ public:
+  CudaGraphExec(cudaGraphExec_t handle = nullptr);
+  CudaGraphExec(CudaGraphExec&& other);
+  ~CudaGraphExec();
+
+  CudaGraphExec(const CudaGraphExec&) = delete;
+  CudaGraphExec& operator=(const CudaGraphExec&) = delete;
+
+  void instantiate(cudaGraph_t graph);
+  void reset();
+
+  operator cudaGraphExec_t() const {
+    return handle_;
+  }
+
+ private:
+  cudaGraphExec_t handle_;
+};
+
 // Throw exception if the cuda API does not succeed.
 void check_cublas_error(const char* name, cublasStatus_t err);
 void check_cuda_error(const char* name, cudaError_t err);
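The implicit operator cudaGraphExec_t() is what lets commit() keep treating the wrapper like a raw handle: it can be compared against nullptr and passed directly to CUDA runtime calls such as cudaGraphLaunch. A short sketch (`launch` is hypothetical; CHECK_CUDA_ERROR is the macro used throughout this diff):

    void launch(cudaGraph_t graph, cudaStream_t stream) {
      CudaGraphExec exec;
      exec.instantiate(graph);
      if (exec != nullptr) {  // compiles via operator cudaGraphExec_t()
        CHECK_CUDA_ERROR(cudaGraphLaunch(exec, stream));
      }
    }  // handle released by ~CudaGraphExec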