Fix unintuitive metal kernel caching (#2242)

* Fix unintuitive metal kernel caching

* alternative solution
This commit is contained in:
Awni Hannun
2025-06-06 20:08:15 -07:00
committed by GitHub
parent 2e8cf0b450
commit 1ca616844b
13 changed files with 713 additions and 593 deletions

View File

@@ -295,7 +295,7 @@ void CommandEncoder::barrier() {
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
library_map_ = {{"mlx", load_default_library(device_)}};
default_library_ = load_default_library(device_);
arch_ = std::string(device_->architecture()->name()->utf8String());
auto arch = arch_.back();
switch (arch) {
@@ -326,11 +326,11 @@ Device::Device() {
Device::~Device() {
auto pool = new_scoped_memory_pool();
for (auto& k : kernel_map_) {
k.second->release();
}
for (auto& l : library_map_) {
l.second->release();
for (auto& [l, kernel_map] : library_kernels_) {
l->release();
for (auto& [_, k] : kernel_map) {
k->release();
}
}
stream_map_.clear();
device_->release();
@@ -474,13 +474,24 @@ CommandEncoder& Device::get_command_encoder(int index) {
return *stream.encoder;
}
void Device::register_library(
const std::string& lib_name,
const std::string& lib_path) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
auto new_lib = load_library(device_, lib_name, lib_path.c_str());
library_map_.insert({lib_name, new_lib});
// Return the Metal library cached under `name`. When no entry exists yet, the
// library is loaded via load_library(device_, name, path), stored in
// library_map_, and returned.
MTL::Library* Device::get_library(
    const std::string& name,
    const std::string& path /* = "" */) {
  // First probe: a shared lock lets several readers hit the cache at once.
  {
    std::shared_lock read_guard(library_mtx_);
    auto cached = library_map_.find(name);
    if (cached != library_map_.end()) {
      return cached->second;
    }
  }

  // Cache miss: take the lock exclusively and look again, since another
  // thread may have inserted the library between the two lock acquisitions.
  std::unique_lock write_guard(library_mtx_);
  auto cached = library_map_.find(name);
  if (cached != library_map_.end()) {
    return cached->second;
  }

  auto loaded = load_library(device_, name, path.c_str());
  library_map_.insert({name, loaded});
  return loaded;
}
MTL::Library* Device::build_library_(const std::string& source_string) {
@@ -649,6 +660,19 @@ MTL::Library* Device::get_library(
return mtl_lib;
}
// Drop the library registered under `name`, releasing every kernel cached for
// it and then the library itself. Does nothing when `name` is not registered.
void Device::clear_library(const std::string& name) {
  std::unique_lock wlock(library_mtx_);
  auto it = library_map_.find(name);
  if (it == library_map_.end()) {
    return;
  }

  // A library only gets a library_kernels_ entry once a kernel has been looked
  // up for it, so a library that was loaded but never used has none. Guard the
  // lookup instead of dereferencing a possible end() iterator.
  // NOTE(review): the kernel lookups access library_kernels_ under kernel_mtx_,
  // while only library_mtx_ is held here - confirm callers never clear a
  // library concurrently with kernel lookups.
  if (auto kernel_map_it = library_kernels_.find(it->second);
      kernel_map_it != library_kernels_.end()) {
    for (auto& [_, kernel] : kernel_map_it->second) {
      kernel->release();
    }
    library_kernels_.erase(kernel_map_it);
  }

  it->second->release();
  library_map_.erase(it);
}
MTL::LinkedFunctions* Device::get_linked_functions_(
const std::vector<MTL::Function*>& funcs) {
if (funcs.empty()) {
@@ -679,6 +703,7 @@ MTL::ComputePipelineState* Device::get_kernel_(
std::unique_lock wlock(kernel_mtx_);
// Try loading again to avoid loading twice
auto& kernel_map_ = library_kernels_[mtl_lib];
if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
return it->second;
}
@@ -713,6 +738,7 @@ MTL::ComputePipelineState* Device::get_kernel(
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
auto& kernel_map_ = library_kernels_[mtl_lib];
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
@@ -722,23 +748,11 @@ MTL::ComputePipelineState* Device::get_kernel(
MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
{
// Multiple readers allowed
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
}
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_(lib_name);
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
return get_kernel(
base_name, default_library_, hash_name, func_consts, linked_functions);
}
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {