diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 4129563af..638d68727 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -15,7 +15,12 @@ namespace mlx::core { // This should be less than 255 constexpr int default_max_nodes_per_graph = 20; -constexpr int max_graph_cache_size = 100; +int cuda_graph_cache_size() { + static int cache_size = []() { + return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100); + }(); + return cache_size; +} namespace cu { @@ -252,10 +257,6 @@ void CommandEncoder::commit() { CHECK_CUDA_ERROR(cudaGraphAddDependencies( graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size())); } - // TODO smarter cache policy - if (graph_cache_.size() > max_graph_cache_size) { - clear_graphs(graph_cache_); - } graph_key_ += "."; graph_key_ += std::to_string(node_count_); @@ -281,6 +282,11 @@ void CommandEncoder::commit() { } CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); + // TODO smarter cache policy + if (graph_cache_.size() > cuda_graph_cache_size()) { + clear_graphs(graph_cache_); + } + // Reset state node_count_ = 0; graph_node_count_ = 0;