mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
2 Commits
50cc09887f
...
787c0d90cd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
787c0d90cd | ||
|
|
e8b604a6a3 |
@@ -47,7 +47,7 @@ auto& conv_cache() {
|
|||||||
std::pair<
|
std::pair<
|
||||||
cudnnBackendDescriptorType_t,
|
cudnnBackendDescriptorType_t,
|
||||||
std::optional<cudnn_frontend::ExecutionPlan>>>
|
std::optional<cudnn_frontend::ExecutionPlan>>>
|
||||||
cache(/* capacity */ 128);
|
cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
|
||||||
return cache;
|
return cache;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -27,13 +27,6 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int cuda_graph_cache_size() {
|
|
||||||
static int cache_size = []() {
|
|
||||||
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
|
|
||||||
}();
|
|
||||||
return cache_size;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool use_cuda_graphs() {
|
bool use_cuda_graphs() {
|
||||||
static bool use_graphs = []() {
|
static bool use_graphs = []() {
|
||||||
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
||||||
@@ -203,7 +196,7 @@ CommandEncoder::CommandEncoder(Device& d)
|
|||||||
: device_(d),
|
: device_(d),
|
||||||
stream_(d),
|
stream_(d),
|
||||||
graph_(d),
|
graph_(d),
|
||||||
graph_cache_(cuda_graph_cache_size()) {}
|
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
|
||||||
|
|
||||||
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||||
worker_.add_task(std::move(task));
|
worker_.add_task(std::move(task));
|
||||||
|
|||||||
@@ -2,11 +2,15 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include <fmt/format.h>
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
template <
|
template <
|
||||||
@@ -27,6 +31,14 @@ class LRUCache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize with capacity read from |env_name|.
|
||||||
|
LRUCache(const char* env_name, int default_capacity)
|
||||||
|
: LRUCache(env::get_var(env_name, default_capacity)) {
|
||||||
|
if (env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1)) {
|
||||||
|
env_name_ = env_name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
size_t size() const {
|
size_t size() const {
|
||||||
return map_.size();
|
return map_.size();
|
||||||
}
|
}
|
||||||
@@ -76,6 +88,14 @@ class LRUCache {
|
|||||||
return {it->second, false};
|
return {it->second, false};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (env_name_ && ++cache_misses_ > 2 * capacity_) {
|
||||||
|
throw std::runtime_error(fmt::format(
|
||||||
|
"Cache thrashing is happening, please set the environment variable "
|
||||||
|
"{} to a larger value than {} to fix degraded performance.",
|
||||||
|
env_name_,
|
||||||
|
capacity_));
|
||||||
|
}
|
||||||
|
|
||||||
vlist_.emplace_front(key, std::forward<U>(value));
|
vlist_.emplace_front(key, std::forward<U>(value));
|
||||||
map_[key] = vlist_.begin();
|
map_[key] = vlist_.begin();
|
||||||
|
|
||||||
@@ -106,6 +126,9 @@ class LRUCache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const char* env_name_{nullptr};
|
||||||
|
size_t cache_misses_{0};
|
||||||
|
|
||||||
list_type vlist_;
|
list_type vlist_;
|
||||||
map_type map_;
|
map_type map_;
|
||||||
size_t capacity_;
|
size_t capacity_;
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
#include "mlx/backend/cuda/device/config.h"
|
#include "mlx/backend/cuda/device/config.h"
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/lru_cache.h"
|
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/fast_primitives.h"
|
#include "mlx/fast_primitives.h"
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
|
|||||||
}
|
}
|
||||||
|
|
||||||
MTL::Library* load_default_library(MTL::Device* device) {
|
MTL::Library* load_default_library(MTL::Device* device) {
|
||||||
NS::Error* error[4];
|
NS::Error* error[5];
|
||||||
MTL::Library* lib;
|
MTL::Library* lib;
|
||||||
// First try the colocated mlx.metallib
|
// First try the colocated mlx.metallib
|
||||||
std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
|
std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
|
||||||
@@ -127,12 +127,19 @@ MTL::Library* load_default_library(MTL::Device* device) {
|
|||||||
return lib;
|
return lib;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try to load resources from Framework resources if SwiftPM wrapped as a
|
||||||
|
// dynamic framework.
|
||||||
|
std::tie(lib, error[3]) = load_colocated_library(device, "Resources/default");
|
||||||
|
if (lib) {
|
||||||
|
return lib;
|
||||||
|
}
|
||||||
|
|
||||||
// Finally try default_mtllib_path
|
// Finally try default_mtllib_path
|
||||||
std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path);
|
std::tie(lib, error[4]) = load_library_from_path(device, default_mtllib_path);
|
||||||
if (!lib) {
|
if (!lib) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "Failed to load the default metallib. ";
|
msg << "Failed to load the default metallib. ";
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 5; i++) {
|
||||||
if (error[i] != nullptr) {
|
if (error[i] != nullptr) {
|
||||||
msg << error[i]->localizedDescription()->utf8String() << " ";
|
msg << error[i]->localizedDescription()->utf8String() << " ";
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,9 @@ import os
|
|||||||
# Use regular fp32 precision for tests
|
# Use regular fp32 precision for tests
|
||||||
os.environ["MLX_ENABLE_TF32"] = "0"
|
os.environ["MLX_ENABLE_TF32"] = "0"
|
||||||
|
|
||||||
|
# Do not abort on cache thrashing
|
||||||
|
os.environ["MLX_ENABLE_CACHE_THRASHING_CHECK"] = "0"
|
||||||
|
|
||||||
import platform
|
import platform
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Any, Callable, List, Tuple, Union
|
from typing import Any, Callable, List, Tuple, Union
|
||||||
|
|||||||
Reference in New Issue
Block a user