mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-12 20:26:40 +08:00
Use LRU cache for cuda graph (#2448)
* Use LRU cache for cuda graph * Remove unused destructor
This commit is contained in:
parent
8831064493
commit
aaf78f4c6b
@ -184,21 +184,11 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
|
CommandEncoder::CommandEncoder(Device& d)
|
||||||
|
: device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) {
|
||||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {
|
|
||||||
for (auto& [_, graph_exec] : graphs) {
|
|
||||||
CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
|
|
||||||
}
|
|
||||||
graphs.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
CommandEncoder::~CommandEncoder() {
|
|
||||||
clear_graphs(graph_cache_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||||
worker_.add_task(std::move(task));
|
worker_.add_task(std::move(task));
|
||||||
}
|
}
|
||||||
@ -290,7 +280,7 @@ void CommandEncoder::commit() {
|
|||||||
graph_key_ += ".";
|
graph_key_ += ".";
|
||||||
graph_key_ += std::to_string(empty_node_count_);
|
graph_key_ += std::to_string(empty_node_count_);
|
||||||
|
|
||||||
cudaGraphExec_t& graph_exec = graph_cache_[graph_key_];
|
CudaGraphExec& graph_exec = graph_cache_[graph_key_];
|
||||||
|
|
||||||
if (graph_exec != nullptr) {
|
if (graph_exec != nullptr) {
|
||||||
cudaGraphExecUpdateResult update_result;
|
cudaGraphExecUpdateResult update_result;
|
||||||
@ -304,22 +294,15 @@ void CommandEncoder::commit() {
|
|||||||
#endif // CUDART_VERSION >= 12000
|
#endif // CUDART_VERSION >= 12000
|
||||||
if (update_result != cudaGraphExecUpdateSuccess) {
|
if (update_result != cudaGraphExecUpdateSuccess) {
|
||||||
cudaGetLastError(); // reset error
|
cudaGetLastError(); // reset error
|
||||||
CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
|
graph_exec.reset();
|
||||||
graph_exec = nullptr;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (graph_exec == nullptr) {
|
if (graph_exec == nullptr) {
|
||||||
CHECK_CUDA_ERROR(
|
graph_exec.instantiate(graph_);
|
||||||
cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
|
|
||||||
}
|
}
|
||||||
device_.make_current();
|
device_.make_current();
|
||||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||||
|
|
||||||
// TODO smarter cache policy
|
|
||||||
if (graph_cache_.size() > cuda_graph_cache_size()) {
|
|
||||||
clear_graphs(graph_cache_);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset state
|
// Reset state
|
||||||
node_count_ = 0;
|
node_count_ = 0;
|
||||||
graph_node_count_ = 0;
|
graph_node_count_ = 0;
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/cuda/lru_cache.h"
|
||||||
#include "mlx/backend/cuda/worker.h"
|
#include "mlx/backend/cuda/worker.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
@ -31,7 +32,6 @@ class CommandEncoder {
|
|||||||
};
|
};
|
||||||
|
|
||||||
explicit CommandEncoder(Device& d);
|
explicit CommandEncoder(Device& d);
|
||||||
~CommandEncoder();
|
|
||||||
|
|
||||||
CommandEncoder(const CommandEncoder&) = delete;
|
CommandEncoder(const CommandEncoder&) = delete;
|
||||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||||
@ -126,7 +126,7 @@ class CommandEncoder {
|
|||||||
std::string graph_key_;
|
std::string graph_key_;
|
||||||
std::vector<GraphNode> concurrent_nodes_;
|
std::vector<GraphNode> concurrent_nodes_;
|
||||||
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
||||||
std::unordered_map<std::string, cudaGraphExec_t> graph_cache_;
|
LRUCache<std::string, CudaGraphExec> graph_cache_;
|
||||||
std::vector<std::uintptr_t> active_deps_;
|
std::vector<std::uintptr_t> active_deps_;
|
||||||
std::vector<std::uintptr_t> active_outputs_;
|
std::vector<std::uintptr_t> active_outputs_;
|
||||||
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
|
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@ -20,7 +21,11 @@ class LRUCache {
|
|||||||
using const_iterator = typename list_type::const_iterator;
|
using const_iterator = typename list_type::const_iterator;
|
||||||
using map_type = M<K, iterator>;
|
using map_type = M<K, iterator>;
|
||||||
|
|
||||||
explicit LRUCache(size_t capacity) : capacity_(capacity) {}
|
explicit LRUCache(size_t capacity) : capacity_(capacity) {
|
||||||
|
if (capacity == 0) {
|
||||||
|
throw std::runtime_error("LRUCache requires capacity > 0.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
size_t size() const {
|
size_t size() const {
|
||||||
return map_.size();
|
return map_.size();
|
||||||
@ -84,6 +89,14 @@ class LRUCache {
|
|||||||
return vlist_.erase(pos);
|
return vlist_.erase(pos);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
V& operator[](const K& key) {
|
||||||
|
auto it = find(key);
|
||||||
|
if (it == end()) {
|
||||||
|
it = emplace(key, V{}).first;
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void trim() {
|
void trim() {
|
||||||
while (map_.size() > capacity_) {
|
while (map_.size() > capacity_) {
|
||||||
|
@ -17,6 +17,27 @@ CudaStream::~CudaStream() {
|
|||||||
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
|
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CudaGraphExec::CudaGraphExec(cudaGraphExec_t handle) : handle_(handle) {}
|
||||||
|
|
||||||
|
CudaGraphExec::CudaGraphExec(CudaGraphExec&& other) : handle_(other.handle_) {
|
||||||
|
other.handle_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
CudaGraphExec::~CudaGraphExec() {
|
||||||
|
reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CudaGraphExec::instantiate(cudaGraph_t graph) {
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
void CudaGraphExec::reset() {
|
||||||
|
if (handle_ != nullptr) {
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphExecDestroy(handle_));
|
||||||
|
handle_ = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void check_cublas_error(const char* name, cublasStatus_t err) {
|
void check_cublas_error(const char* name, cublasStatus_t err) {
|
||||||
if (err != CUBLAS_STATUS_SUCCESS) {
|
if (err != CUBLAS_STATUS_SUCCESS) {
|
||||||
// TODO: Use cublasGetStatusString when it is widely available.
|
// TODO: Use cublasGetStatusString when it is widely available.
|
||||||
|
@ -33,6 +33,27 @@ class CudaStream {
|
|||||||
cudaStream_t stream_;
|
cudaStream_t stream_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Move-able RAII handle of cudaGraphExec_t.
|
||||||
|
class CudaGraphExec {
|
||||||
|
public:
|
||||||
|
CudaGraphExec(cudaGraphExec_t handle = nullptr);
|
||||||
|
CudaGraphExec(CudaGraphExec&& other);
|
||||||
|
~CudaGraphExec();
|
||||||
|
|
||||||
|
CudaGraphExec(const CudaGraphExec&) = delete;
|
||||||
|
CudaGraphExec& operator=(const CudaGraphExec&) = delete;
|
||||||
|
|
||||||
|
void instantiate(cudaGraph_t graph);
|
||||||
|
void reset();
|
||||||
|
|
||||||
|
operator cudaGraphExec_t() const {
|
||||||
|
return handle_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
cudaGraphExec_t handle_;
|
||||||
|
};
|
||||||
|
|
||||||
// Throw exception if the cuda API does not succeed.
|
// Throw exception if the cuda API does not succeed.
|
||||||
void check_cublas_error(const char* name, cublasStatus_t err);
|
void check_cublas_error(const char* name, cublasStatus_t err);
|
||||||
void check_cuda_error(const char* name, cudaError_t err);
|
void check_cuda_error(const char* name, cudaError_t err);
|
||||||
|
Loading…
Reference in New Issue
Block a user