From 787c0d90cdbd9b7a444dde926d779c335a6bc752 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Fri, 19 Sep 2025 09:12:14 +0900
Subject: [PATCH] Detect cache thrashing in LRUCache (#2600)

* Detect cache thrashing in LRUCache

* Do not check cache thrashing in tests
---
 mlx/backend/cuda/conv.cpp                     |  2 +-
 mlx/backend/cuda/device.cpp                   |  9 +-------
 mlx/backend/cuda/lru_cache.h                  | 23 +++++++++++++++++++
 .../cuda/scaled_dot_product_attention.cu      |  1 -
 python/tests/mlx_tests.py                     |  3 +++
 5 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp
index 63188fbc8..a5bc8e41a 100644
--- a/mlx/backend/cuda/conv.cpp
+++ b/mlx/backend/cuda/conv.cpp
@@ -47,7 +47,7 @@ auto& conv_cache() {
       std::pair<
           cudnnBackendDescriptorType_t,
           std::optional<cudnn_frontend::ExecutionPlan>>>
-      cache(/* capacity */ 128);
+      cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
   return cache;
 }
 
diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index d7b9a0328..bf0946a7b 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -27,13 +27,6 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
   }
 }
 
-int cuda_graph_cache_size() {
-  static int cache_size = []() {
-    return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
-  }();
-  return cache_size;
-}
-
 bool use_cuda_graphs() {
   static bool use_graphs = []() {
     return env::get_var("MLX_USE_CUDA_GRAPHS", true);
@@ -203,7 +196,7 @@ CommandEncoder::CommandEncoder(Device& d)
     : device_(d),
       stream_(d),
       graph_(d),
-      graph_cache_(cuda_graph_cache_size()) {}
+      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
 
 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
diff --git a/mlx/backend/cuda/lru_cache.h b/mlx/backend/cuda/lru_cache.h
index 79fca0ae7..dc8325fcd 100644
--- a/mlx/backend/cuda/lru_cache.h
+++ b/mlx/backend/cuda/lru_cache.h
@@ -2,11 +2,15 @@
 
 #pragma once
 
+#include "mlx/utils.h"
+
 #include <cstddef>
 #include <list>
 #include <unordered_map>
 #include <utility>
 
+#include <fmt/format.h>
+
 namespace mlx::core {
 
 template <
@@ -27,6 +31,14 @@ class LRUCache {
     }
   }
 
+  // Initialize with capacity read from |env_name|.
+  LRUCache(const char* env_name, int default_capacity)
+      : LRUCache(env::get_var(env_name, default_capacity)) {
+    if (env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1)) {
+      env_name_ = env_name;
+    }
+  }
+
   size_t size() const {
     return map_.size();
   }
@@ -76,6 +88,14 @@ class LRUCache {
       return {it->second, false};
     }
 
+    if (env_name_ && ++cache_misses_ > 2 * capacity_) {
+      throw std::runtime_error(fmt::format(
+          "Cache thrashing is happening, please set the environment variable "
+          "{} to a larger value than {} to fix degraded performance.",
+          env_name_,
+          capacity_));
+    }
+
     vlist_.emplace_front(key, std::forward<U>(value));
     map_[key] = vlist_.begin();
 
@@ -106,6 +126,9 @@ class LRUCache {
     }
   }
 
+  const char* env_name_{nullptr};
+  size_t cache_misses_{0};
+
   list_type vlist_;
   map_type map_;
   size_t capacity_;
diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu
index 2095bdb43..7d5437ef4 100644
--- a/mlx/backend/cuda/scaled_dot_product_attention.cu
+++ b/mlx/backend/cuda/scaled_dot_product_attention.cu
@@ -4,7 +4,6 @@
 #include "mlx/backend/cuda/device/config.h"
 #include "mlx/backend/cuda/device/utils.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
-#include "mlx/backend/cuda/lru_cache.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/fast_primitives.h"
diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py
index bc197b673..c344e7c86 100644
--- a/python/tests/mlx_tests.py
+++ b/python/tests/mlx_tests.py
@@ -5,6 +5,9 @@ import os
 # Use regular fp32 precision for tests
 os.environ["MLX_ENABLE_TF32"] = "0"
 
+# Do not abort on cache thrashing
+os.environ["MLX_ENABLE_CACHE_THRASHING_CHECK"] = "0"
+
 import platform
 import unittest
 from typing import Any, Callable, List, Tuple, Union
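
For readers who want to see the heuristic in isolation, below is a minimal standalone sketch of what the patch adds to LRUCache. It is not mlx code: ToyLRUCache, TOY_CACHE_SIZE, and the main() driver are hypothetical names, and the sketch uses only the standard library. It mirrors the two behaviours introduced above: the capacity is read from a named environment variable with a default, and once the cumulative number of cache misses exceeds twice the capacity, a runtime error is raised that names the variable to increase.

// Standalone illustration of the thrashing heuristic (assumed names, not mlx API).
#include <cstdlib>
#include <iostream>
#include <list>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>

template <typename K, typename V>
class ToyLRUCache {
 public:
  // Capacity comes from |env_name| if set, otherwise |default_capacity|,
  // mimicking the new LRUCache(const char*, int) constructor.
  ToyLRUCache(const char* env_name, size_t default_capacity)
      : env_name_(env_name), capacity_(default_capacity) {
    if (const char* v = std::getenv(env_name)) {
      capacity_ = std::stoul(v);
    }
  }

  void emplace(const K& key, V value) {
    if (map_.count(key)) {
      return;  // hit: nothing to insert
    }
    // The heuristic from the patch: once cumulative misses exceed twice the
    // capacity, assume the cache is too small and point at the env variable.
    if (++misses_ > 2 * capacity_) {
      throw std::runtime_error(
          std::string("cache thrashing, raise ") + env_name_);
    }
    items_.emplace_front(key, std::move(value));
    map_[key] = items_.begin();
    if (map_.size() > capacity_) {  // evict the least recently used entry
      map_.erase(items_.back().first);
      items_.pop_back();
    }
  }

 private:
  const char* env_name_;
  size_t capacity_;
  size_t misses_{0};
  std::list<std::pair<K, V>> items_;
  std::unordered_map<K, typename std::list<std::pair<K, V>>::iterator> map_;
};

int main() {
  ToyLRUCache<int, int> cache("TOY_CACHE_SIZE", /* default_capacity */ 4);
  try {
    for (int i = 0; i < 100; ++i) {
      cache.emplace(i, i);  // every key is new, so every insert is a miss
    }
  } catch (const std::runtime_error& e) {
    std::cout << e.what() << "\n";  // fires once misses exceed 2 * capacity
  }
  return 0;
}

With the patch applied, the corresponding knobs in mlx are MLX_CUDA_CONV_CACHE_SIZE and MLX_CUDA_GRAPH_CACHE_SIZE for the two cache capacities, and MLX_ENABLE_CACHE_THRASHING_CHECK, which the test suite sets to 0 to disable the check.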