[CUDA] Add MLX_CUDA_GRAPH_CACHE_SIZE env for setting graph cache size (#2329)

This commit is contained in:
Cheng 2025-07-06 00:33:29 +09:00 committed by GitHub
parent 8917022deb
commit 0e0d9ac522
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,7 +15,12 @@ namespace mlx::core {
// This should be less than 255
constexpr int default_max_nodes_per_graph = 20;
// Size threshold for graph_cache_: CommandEncoder::commit() clears the cache
// once graph_cache_.size() exceeds this value.
constexpr int max_graph_cache_size = 100;
// Returns the graph cache size setting read from the
// MLX_CUDA_GRAPH_CACHE_SIZE environment variable via env::get_var (called
// with 100 as its second argument, presumably the value used when the
// variable is unset). The lookup runs once, on the first call; every later
// call returns the same stored value.
int cuda_graph_cache_size() {
  static const int limit = env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
  return limit;
}
namespace cu {
@ -252,10 +257,6 @@ void CommandEncoder::commit() {
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
}
// TODO smarter cache policy
if (graph_cache_.size() > max_graph_cache_size) {
clear_graphs(graph_cache_);
}
graph_key_ += ".";
graph_key_ += std::to_string(node_count_);
@ -281,6 +282,11 @@ void CommandEncoder::commit() {
}
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
// TODO smarter cache policy
if (graph_cache_.size() > cuda_graph_cache_size()) {
clear_graphs(graph_cache_);
}
// Reset state
node_count_ = 0;
graph_node_count_ = 0;