Compare commits

...

2 Commits

Author SHA1 Message Date
Cheng
787c0d90cd Detect cache thrashing in LRUCache (#2600)
* Detect cache thrashing in LRUCache

* Do not check cache thrashing in tests
2025-09-19 09:12:14 +09:00
Oleksandr Bilous
e8b604a6a3 fix: library loading for swift dynamic frameworks (#2568) 2025-09-18 13:54:59 -07:00
6 changed files with 38 additions and 13 deletions

View File

@@ -47,7 +47,7 @@ auto& conv_cache() {
std::pair<
cudnnBackendDescriptorType_t,
std::optional<cudnn_frontend::ExecutionPlan>>>
cache(/* capacity */ 128);
cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
return cache;
}

View File

@@ -27,13 +27,6 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
}
}
// Returns the CUDA graph cache size obtained from
// env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400).
int cuda_graph_cache_size() {
// Function-local static initialized by an immediately-invoked lambda: the
// lookup runs on the first call only, and every later call returns the same
// stored value.
static int cache_size = []() {
// NOTE(review): env::get_var is defined elsewhere; presumably 400 is the
// fallback used when the variable is not set -- confirm.
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
}();
return cache_size;
}
bool use_cuda_graphs() {
static bool use_graphs = []() {
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
@@ -203,7 +196,7 @@ CommandEncoder::CommandEncoder(Device& d)
: device_(d),
stream_(d),
graph_(d),
graph_cache_(cuda_graph_cache_size()) {}
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));

View File

@@ -2,11 +2,15 @@
#pragma once
#include "mlx/utils.h"
#include <cstring>
#include <list>
#include <unordered_map>
#include <utility>
#include <fmt/format.h>
namespace mlx::core {
template <
@@ -27,6 +31,14 @@ class LRUCache {
}
}
// Initialize with capacity read from |env_name|.
//
// Delegates to another LRUCache constructor, passing the result of
// env::get_var(env_name, default_capacity) as its argument.
// NOTE(review): env::get_var is defined elsewhere; presumably
// |default_capacity| is what it returns when |env_name| is not set -- confirm.
//
// |env_name| is remembered in env_name_ only when
// env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1) evaluates to true;
// otherwise env_name_ keeps its nullptr default.
LRUCache(const char* env_name, int default_capacity)
: LRUCache(env::get_var(env_name, default_capacity)) {
if (env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1)) {
// Only the pointer is stored, not a copy of the characters, so the string
// passed as |env_name| has to stay valid for as long as this cache is used.
env_name_ = env_name;
}
}
// Returns the number of entries currently stored, as reported by map_.size().
size_t size() const {
return map_.size();
}
@@ -76,6 +88,14 @@ class LRUCache {
return {it->second, false};
}
if (env_name_ && ++cache_misses_ > 2 * capacity_) {
throw std::runtime_error(fmt::format(
"Cache thrashing is happening, please set the environment variable "
"{} to a larger value than {} to fix degraded performance.",
env_name_,
capacity_));
}
vlist_.emplace_front(key, std::forward<U>(value));
map_[key] = vlist_.begin();
@@ -106,6 +126,9 @@ class LRUCache {
}
}
const char* env_name_{nullptr};
size_t cache_misses_{0};
list_type vlist_;
map_type map_;
size_t capacity_;

View File

@@ -4,7 +4,6 @@
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"

View File

@@ -108,7 +108,7 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
}
MTL::Library* load_default_library(MTL::Device* device) {
NS::Error* error[4];
NS::Error* error[5];
MTL::Library* lib;
// First try the colocated mlx.metallib
std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
@@ -127,12 +127,19 @@ MTL::Library* load_default_library(MTL::Device* device) {
return lib;
}
// Try to load resources from Framework resources if SwiftPM is wrapped as a
// dynamic framework.
std::tie(lib, error[3]) = load_colocated_library(device, "Resources/default");
if (lib) {
return lib;
}
// Finally try default_mtllib_path
std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path);
std::tie(lib, error[4]) = load_library_from_path(device, default_mtllib_path);
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the default metallib. ";
for (int i = 0; i < 4; i++) {
for (int i = 0; i < 5; i++) {
if (error[i] != nullptr) {
msg << error[i]->localizedDescription()->utf8String() << " ";
}

View File

@@ -5,6 +5,9 @@ import os
# Use regular fp32 precision for tests
os.environ["MLX_ENABLE_TF32"] = "0"
# Do not abort on cache thrashing
os.environ["MLX_ENABLE_CACHE_THRASHING_CHECK"] = "0"
import platform
import unittest
from typing import Any, Callable, List, Tuple, Union