diff --git a/mlx/backend/cuda/lru_cache.h b/mlx/backend/cuda/lru_cache.h
index 7294f2477..c8df2fa93 100644
--- a/mlx/backend/cuda/lru_cache.h
+++ b/mlx/backend/cuda/lru_cache.h
@@ -16,8 +16,9 @@ class LRUCache {
  public:
   using value_type = std::pair<K, V>;
   using list_type = std::list<value_type>;
-  using list_iter = typename list_type::iterator;
-  using map_type = M<K, list_iter>;
+  using iterator = typename list_type::iterator;
+  using const_iterator = typename list_type::const_iterator;
+  using map_type = M<K, iterator>;
 
   explicit LRUCache(size_t capacity) : capacity_(capacity) {}
 
@@ -36,16 +37,16 @@ class LRUCache {
     trim();
   }
 
-  auto begin() {
+  iterator begin() {
     return vlist_.begin();
   }
-  auto begin() const {
+  const_iterator begin() const {
     return vlist_.begin();
   }
-  auto end() {
+  iterator end() {
     return vlist_.end();
   }
-  auto end() const {
+  const_iterator end() const {
     return vlist_.end();
   }
 
@@ -54,7 +55,7 @@ class LRUCache {
     vlist_.clear();
   }
 
-  list_iter find(const K& key) {
+  iterator find(const K& key) {
     auto it = map_.find(key);
     if (it == map_.end())
       return end();
@@ -62,14 +63,15 @@
     return it->second;
   }
 
-  std::pair<list_iter, bool> emplace(const K& key, V value) {
+  template <typename U>
+  std::pair<iterator, bool> emplace(const K& key, U&& value) {
     auto it = map_.find(key);
     if (it != map_.end()) {
       vlist_.splice(vlist_.begin(), vlist_, it->second);
       return {it->second, false};
     }
 
-    vlist_.emplace_front(key, std::move(value));
+    vlist_.emplace_front(key, std::forward<U>(value));
     map_[key] = vlist_.begin();
 
     trim();
@@ -77,7 +79,7 @@
     return {vlist_.begin(), true};
   }
 
-  list_iter erase(list_iter pos) {
+  iterator erase(iterator pos) {
     map_.erase(pos->first);
     return vlist_.erase(pos);
   }
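
For review context, a minimal usage sketch of the interface after this change follows. It is illustrative only: the enclosing namespace, the exact template-parameter list of LRUCache, and its default map type are not visible in this hunk, so the instantiation below is an assumption and only the members touched by the diff are exercised.

// Usage sketch (assumptions: namespace qualification is omitted and the
// LRUCache<std::string, std::vector<int>> instantiation is hypothetical).
#include <cassert>
#include <string>
#include <utility>
#include <vector>

#include "mlx/backend/cuda/lru_cache.h"

int main() {
  // Assumed instantiation: string keys, vector payloads, capacity of 2.
  LRUCache<std::string, std::vector<int>> cache(2);

  // emplace() is now a forwarding template, so an rvalue argument is moved
  // into the node instead of being copied through a by-value parameter.
  std::vector<int> big(1024, 7);
  auto [pos, inserted] = cache.emplace("a", std::move(big));
  assert(inserted && pos->first == "a");

  // An lvalue still goes through the same overload (it is copied).
  std::vector<int> small{1, 2, 3};
  cache.emplace("b", small);

  // find() returns the named `iterator` type (previously `list_iter`).
  if (auto hit = cache.find("a"); hit != cache.end()) {
    // value_type is std::pair<K, V>, so the payload is hit->second.
    assert(hit->second.size() == 1024);
  }

  // erase(iterator) drops the entry and returns the next iterator.
  if (auto hit = cache.find("b"); hit != cache.end()) {
    cache.erase(hit);
  }

  // begin()/end() now return iterator (const_iterator on a const cache),
  // so range-for and standard algorithms see concrete types.
  for (const auto& [key, value] : cache) {
    (void)key;
    (void)value;
  }
  return 0;
}

The forwarding overload avoids an extra copy or move of V relative to the old by-value signature, and the named iterator/const_iterator aliases let callers spell the return types of find(), emplace(), and erase() without auto.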