Use LRU cache for cuda graph (#2448)

* Use LRU cache for cuda graph

* Remove unused destructor
This commit is contained in:
Cheng
2025-08-02 21:28:57 +09:00
committed by GitHub
parent 8831064493
commit aaf78f4c6b
5 changed files with 63 additions and 25 deletions

View File

@@ -3,6 +3,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h"
@@ -31,7 +32,6 @@ class CommandEncoder {
};
explicit CommandEncoder(Device& d);
~CommandEncoder();
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
@@ -126,7 +126,7 @@ class CommandEncoder {
std::string graph_key_;
std::vector<GraphNode> concurrent_nodes_;
std::vector<std::shared_ptr<array::Data>> temporaries_;
std::unordered_map<std::string, cudaGraphExec_t> graph_cache_;
LRUCache<std::string, CudaGraphExec> graph_cache_;
std::vector<std::uintptr_t> active_deps_;
std::vector<std::uintptr_t> active_outputs_;
std::unordered_map<std::uintptr_t, GraphNode> node_map_;