mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use LRU cache for cuda graph (#2448)
* Use LRU cache for cuda graph * Remove unused destructor
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
@@ -31,7 +32,6 @@ class CommandEncoder {
|
||||
};
|
||||
|
||||
explicit CommandEncoder(Device& d);
|
||||
~CommandEncoder();
|
||||
|
||||
CommandEncoder(const CommandEncoder&) = delete;
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
@@ -126,7 +126,7 @@ class CommandEncoder {
|
||||
std::string graph_key_;
|
||||
std::vector<GraphNode> concurrent_nodes_;
|
||||
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
||||
std::unordered_map<std::string, cudaGraphExec_t> graph_cache_;
|
||||
LRUCache<std::string, CudaGraphExec> graph_cache_;
|
||||
std::vector<std::uintptr_t> active_deps_;
|
||||
std::vector<std::uintptr_t> active_outputs_;
|
||||
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
|
||||
|
||||
Reference in New Issue
Block a user