Mirror of https://github.com/ml-explore/mlx.git
Detect cache thrashing in LRUCache (#2600)
* Detect cache thrashing in LRUCache
* Do not check cache thrashing in tests
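Summary of the change: LRUCache gains a second constructor that reads its capacity from a named environment variable, and emplace() now throws a std::runtime_error once the number of cache misses exceeds twice the capacity; the check can be disabled by setting MLX_ENABLE_CACHE_THRASHING_CHECK=0. The cuDNN conv plan cache and the CUDA graph cache switch to the new constructor, which makes the standalone cuda_graph_cache_size() helper redundant, and the Python test harness opts out of the check. Illustrative sketches of the new constructor and of the thrashing check follow the relevant lru_cache.h hunks below.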
@@ -47,7 +47,7 @@ auto& conv_cache() {
     std::pair<
         cudnnBackendDescriptorType_t,
         std::optional<cudnn_frontend::ExecutionPlan>>>
-    cache(/* capacity */ 128);
+    cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
   return cache;
 }
 
@@ -27,13 +27,6 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
   }
 }
 
-int cuda_graph_cache_size() {
-  static int cache_size = []() {
-    return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
-  }();
-  return cache_size;
-}
-
 bool use_cuda_graphs() {
   static bool use_graphs = []() {
     return env::get_var("MLX_USE_CUDA_GRAPHS", true);
@@ -203,7 +196,7 @@ CommandEncoder::CommandEncoder(Device& d)
     : device_(d),
       stream_(d),
       graph_(d),
-      graph_cache_(cuda_graph_cache_size()) {}
+      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
 
 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
@@ -2,11 +2,15 @@
 
 #pragma once
 
+#include "mlx/utils.h"
+
 #include <cstring>
 #include <list>
 #include <unordered_map>
 #include <utility>
 
+#include <fmt/format.h>
+
 namespace mlx::core {
 
 template <
@@ -27,6 +31,14 @@ class LRUCache {
     }
   }
 
+  // Initialize with capacity read from |env_name|.
+  LRUCache(const char* env_name, int default_capacity)
+      : LRUCache(env::get_var(env_name, default_capacity)) {
+    if (env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1)) {
+      env_name_ = env_name;
+    }
+  }
+
   size_t size() const {
     return map_.size();
   }
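A minimal sketch of how the new env-named constructor might be used, assuming LRUCache is templated on key and value types and that env::get_var falls back to the given default when the variable is unset; the cache name and sizes below are hypothetical, not part of this change:

#include "mlx/backend/cuda/lru_cache.h"

using namespace mlx::core;

int main() {
  // Capacity resolves to MLX_EXAMPLE_CACHE_SIZE when that variable is set,
  // otherwise to the default of 64 (both the name and the value are made up).
  LRUCache<int, int> cache("MLX_EXAMPLE_CACHE_SIZE", /* default_capacity */ 64);
  cache.emplace(1, 42); // insert a key/value pair; a lookup miss bumps cache_misses_
  return 0;
}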
@@ -76,6 +88,14 @@ class LRUCache {
       return {it->second, false};
     }
 
+    if (env_name_ && ++cache_misses_ > 2 * capacity_) {
+      throw std::runtime_error(fmt::format(
+          "Cache thrashing is happening, please set the environment variable "
+          "{} to a larger value than {} to fix degraded performance.",
+          env_name_,
+          capacity_));
+    }
+
     vlist_.emplace_front(key, std::forward<U>(value));
     map_[key] = vlist_.begin();
 
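An illustrative sketch of when the new check fires, under the same assumptions and with hypothetical names: every emplace() that misses increments cache_misses_, and once the running count exceeds 2 * capacity_ a std::runtime_error naming the environment variable is thrown, unless MLX_ENABLE_CACHE_THRASHING_CHECK=0 was set when the cache was constructed:

#include <stdexcept>

#include "mlx/backend/cuda/lru_cache.h"

using namespace mlx::core;

int main() {
  // Hypothetical tiny cache: capacity 2 unless MLX_TINY_CACHE_SIZE overrides it.
  LRUCache<int, int> cache("MLX_TINY_CACHE_SIZE", /* default_capacity */ 2);
  try {
    for (int i = 0; i < 100; ++i) {
      cache.emplace(i, i); // every key is new, so every insertion is a miss
    }
  } catch (const std::runtime_error& e) {
    // Thrown once misses exceed 2 * capacity; the message asks for a larger
    // MLX_TINY_CACHE_SIZE.
  }
  return 0;
}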
@@ -106,6 +126,9 @@ class LRUCache {
    }
  }
 
+  const char* env_name_{nullptr};
+  size_t cache_misses_{0};
+
   list_type vlist_;
   map_type map_;
   size_t capacity_;
@@ -4,7 +4,6 @@
 #include "mlx/backend/cuda/device/config.h"
 #include "mlx/backend/cuda/device/utils.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
-#include "mlx/backend/cuda/lru_cache.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/fast_primitives.h"
@@ -5,6 +5,9 @@ import os
 # Use regular fp32 precision for tests
 os.environ["MLX_ENABLE_TF32"] = "0"
 
+# Do not abort on cache thrashing
+os.environ["MLX_ENABLE_CACHE_THRASHING_CHECK"] = "0"
+
 import platform
 import unittest
 from typing import Any, Callable, List, Tuple, Union
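As the test-harness hunk above shows, the abort can be opted out of by setting MLX_ENABLE_CACHE_THRASHING_CHECK=0 in the environment; since the flag is read when a cache is constructed, it has to be set before MLX builds its caches, which is presumably why the test suite sets it at the very top of the module.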