mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-08 02:21:14 +08:00
[CUDA] Add MLX_CUDA_GRAPH_CACHE_SIZE env for setting graph cache size (#2329)
This commit is contained in:
parent
8917022deb
commit
0e0d9ac522
@ -15,7 +15,12 @@ namespace mlx::core {
|
||||
// This should be less than 255
|
||||
constexpr int default_max_nodes_per_graph = 20;
|
||||
|
||||
constexpr int max_graph_cache_size = 100;
|
||||
int cuda_graph_cache_size() {
|
||||
static int cache_size = []() {
|
||||
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
|
||||
}();
|
||||
return cache_size;
|
||||
}
|
||||
|
||||
namespace cu {
|
||||
|
||||
@ -252,10 +257,6 @@ void CommandEncoder::commit() {
|
||||
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
|
||||
graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
|
||||
}
|
||||
// TODO smarter cache policy
|
||||
if (graph_cache_.size() > max_graph_cache_size) {
|
||||
clear_graphs(graph_cache_);
|
||||
}
|
||||
|
||||
graph_key_ += ".";
|
||||
graph_key_ += std::to_string(node_count_);
|
||||
@ -281,6 +282,11 @@ void CommandEncoder::commit() {
|
||||
}
|
||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||
|
||||
// TODO smarter cache policy
|
||||
if (graph_cache_.size() > cuda_graph_cache_size()) {
|
||||
clear_graphs(graph_cache_);
|
||||
}
|
||||
|
||||
// Reset state
|
||||
node_count_ = 0;
|
||||
graph_node_count_ = 0;
|
||||
|
Loading…
Reference in New Issue
Block a user