mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-08 20:31:13 +08:00
[CUDA] Add MLX_CUDA_GRAPH_CACHE_SIZE env for setting graph cache size (#2329)
This commit is contained in:
parent
8917022deb
commit
0e0d9ac522
@ -15,7 +15,12 @@ namespace mlx::core {
|
|||||||
// This should be less than 255
|
// This should be less than 255
|
||||||
constexpr int default_max_nodes_per_graph = 20;
|
constexpr int default_max_nodes_per_graph = 20;
|
||||||
|
|
||||||
constexpr int max_graph_cache_size = 100;
|
int cuda_graph_cache_size() {
|
||||||
|
static int cache_size = []() {
|
||||||
|
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
|
||||||
|
}();
|
||||||
|
return cache_size;
|
||||||
|
}
|
||||||
|
|
||||||
namespace cu {
|
namespace cu {
|
||||||
|
|
||||||
@ -252,10 +257,6 @@ void CommandEncoder::commit() {
|
|||||||
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
|
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
|
||||||
graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
|
graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
|
||||||
}
|
}
|
||||||
// TODO smarter cache policy
|
|
||||||
if (graph_cache_.size() > max_graph_cache_size) {
|
|
||||||
clear_graphs(graph_cache_);
|
|
||||||
}
|
|
||||||
|
|
||||||
graph_key_ += ".";
|
graph_key_ += ".";
|
||||||
graph_key_ += std::to_string(node_count_);
|
graph_key_ += std::to_string(node_count_);
|
||||||
@ -281,6 +282,11 @@ void CommandEncoder::commit() {
|
|||||||
}
|
}
|
||||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||||
|
|
||||||
|
// TODO smarter cache policy
|
||||||
|
if (graph_cache_.size() > cuda_graph_cache_size()) {
|
||||||
|
clear_graphs(graph_cache_);
|
||||||
|
}
|
||||||
|
|
||||||
// Reset state
|
// Reset state
|
||||||
node_count_ = 0;
|
node_count_ = 0;
|
||||||
graph_node_count_ = 0;
|
graph_node_count_ = 0;
|
||||||
|
Loading…
Reference in New Issue
Block a user