mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-26 15:58:14 +08:00
Detect cache thrashing in LRUCache (#2600)
* Detect cache thrashing in LRUCache * Do not check cache thrashing in tests
This commit is contained in:
@@ -47,7 +47,7 @@ auto& conv_cache() {
|
||||
std::pair<
|
||||
cudnnBackendDescriptorType_t,
|
||||
std::optional<cudnn_frontend::ExecutionPlan>>>
|
||||
cache(/* capacity */ 128);
|
||||
cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
|
||||
return cache;
|
||||
}
|
||||
|
||||
|
@@ -27,13 +27,6 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
||||
}
|
||||
}
|
||||
|
||||
int cuda_graph_cache_size() {
|
||||
static int cache_size = []() {
|
||||
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
|
||||
}();
|
||||
return cache_size;
|
||||
}
|
||||
|
||||
bool use_cuda_graphs() {
|
||||
static bool use_graphs = []() {
|
||||
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
||||
@@ -203,7 +196,7 @@ CommandEncoder::CommandEncoder(Device& d)
|
||||
: device_(d),
|
||||
stream_(d),
|
||||
graph_(d),
|
||||
graph_cache_(cuda_graph_cache_size()) {}
|
||||
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
|
||||
|
||||
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||
worker_.add_task(std::move(task));
|
||||
|
@@ -2,11 +2,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <list>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <
|
||||
@@ -27,6 +31,14 @@ class LRUCache {
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize with capacity read from |env_name|.
|
||||
LRUCache(const char* env_name, int default_capacity)
|
||||
: LRUCache(env::get_var(env_name, default_capacity)) {
|
||||
if (env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1)) {
|
||||
env_name_ = env_name;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return map_.size();
|
||||
}
|
||||
@@ -76,6 +88,14 @@ class LRUCache {
|
||||
return {it->second, false};
|
||||
}
|
||||
|
||||
if (env_name_ && ++cache_misses_ > 2 * capacity_) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Cache thrashing is happening, please set the environment variable "
|
||||
"{} to a larger value than {} to fix degraded performance.",
|
||||
env_name_,
|
||||
capacity_));
|
||||
}
|
||||
|
||||
vlist_.emplace_front(key, std::forward<U>(value));
|
||||
map_[key] = vlist_.begin();
|
||||
|
||||
@@ -106,6 +126,9 @@ class LRUCache {
|
||||
}
|
||||
}
|
||||
|
||||
const char* env_name_{nullptr};
|
||||
size_t cache_misses_{0};
|
||||
|
||||
list_type vlist_;
|
||||
map_type map_;
|
||||
size_t capacity_;
|
||||
|
@@ -4,7 +4,6 @@
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
@@ -5,6 +5,9 @@ import os
|
||||
# Use regular fp32 precision for tests
|
||||
os.environ["MLX_ENABLE_TF32"] = "0"
|
||||
|
||||
# Do not abort on cache thrashing
|
||||
os.environ["MLX_ENABLE_CACHE_THRASHING_CHECK"] = "0"
|
||||
|
||||
import platform
|
||||
import unittest
|
||||
from typing import Any, Callable, List, Tuple, Union
|
||||
|
Reference in New Issue
Block a user