Detect cache thrashing in LRUCache (#2600)

* Detect cache thrashing in LRUCache

* Do not check cache thrashing in tests
This commit is contained in:
Cheng
2025-09-19 09:12:14 +09:00
committed by GitHub
parent e8b604a6a3
commit 787c0d90cd
5 changed files with 28 additions and 10 deletions

View File

@@ -47,7 +47,7 @@ auto& conv_cache() {
std::pair<
cudnnBackendDescriptorType_t,
std::optional<cudnn_frontend::ExecutionPlan>>>
cache(/* capacity */ 128);
cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
return cache;
}

View File

@@ -27,13 +27,6 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
}
}
int cuda_graph_cache_size() {
static int cache_size = []() {
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
}();
return cache_size;
}
bool use_cuda_graphs() {
static bool use_graphs = []() {
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
@@ -203,7 +196,7 @@ CommandEncoder::CommandEncoder(Device& d)
: device_(d),
stream_(d),
graph_(d),
graph_cache_(cuda_graph_cache_size()) {}
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));

View File

@@ -2,11 +2,15 @@
#pragma once
#include "mlx/utils.h"
#include <cstring>
#include <list>
#include <unordered_map>
#include <utility>
#include <fmt/format.h>
namespace mlx::core {
template <
@@ -27,6 +31,14 @@ class LRUCache {
}
}
// Initialize with capacity read from |env_name|.
LRUCache(const char* env_name, int default_capacity)
: LRUCache(env::get_var(env_name, default_capacity)) {
if (env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1)) {
env_name_ = env_name;
}
}
size_t size() const {
return map_.size();
}
@@ -76,6 +88,14 @@ class LRUCache {
return {it->second, false};
}
if (env_name_ && ++cache_misses_ > 2 * capacity_) {
throw std::runtime_error(fmt::format(
"Cache thrashing is happening, please set the environment variable "
"{} to a larger value than {} to fix degraded performance.",
env_name_,
capacity_));
}
vlist_.emplace_front(key, std::forward<U>(value));
map_[key] = vlist_.begin();
@@ -106,6 +126,9 @@ class LRUCache {
}
}
const char* env_name_{nullptr};
size_t cache_misses_{0};
list_type vlist_;
map_type map_;
size_t capacity_;

View File

@@ -4,7 +4,6 @@
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"

View File

@@ -5,6 +5,9 @@ import os
# Use regular fp32 precision for tests
os.environ["MLX_ENABLE_TF32"] = "0"
# Do not abort on cache thrashing
os.environ["MLX_ENABLE_CACHE_THRASHING_CHECK"] = "0"
import platform
import unittest
from typing import Any, Callable, List, Tuple, Union