Do not load the default lib if another is requested (#2055)

This commit is contained in:
Angelos Katharopoulos 2025-04-09 13:31:38 -07:00 committed by GitHub
parent e5d35aa187
commit 9ecefd56db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 115 additions and 47 deletions

View File

@ -55,7 +55,10 @@ std::pair<MTL::Library*, NS::Error*> load_library_from_path(
} }
#ifdef SWIFTPM_BUNDLE #ifdef SWIFTPM_BUNDLE
MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) { MTL::Library* try_load_bundle(
MTL::Device* device,
NS::URL* url,
const std::string& lib_name) {
std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" + std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" +
SWIFTPM_BUNDLE + ".bundle"; SWIFTPM_BUNDLE + ".bundle";
auto bundle = NS::Bundle::alloc()->init( auto bundle = NS::Bundle::alloc()->init(
@ -63,8 +66,8 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
if (bundle != nullptr) { if (bundle != nullptr) {
std::string resource_path = std::string resource_path =
std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" + std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
"default.metallib"; lib_name + ".metallib" auto [lib, error] =
auto [lib, error] = load_library_from_path(device, resource_path.c_str()); load_library_from_path(device, resource_path.c_str());
if (lib) { if (lib) {
return lib; return lib;
} }
@ -73,51 +76,124 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
} }
#endif #endif
MTL::Library* load_library(
MTL::Device* device,
const std::string& lib_name = "mlx",
const char* lib_path = default_mtllib_path) {
// Firstly, search for the metallib in the same path as this binary // Firstly, search for the metallib in the same path as this binary
std::string first_path = get_colocated_mtllib_path(lib_name); std::pair<MTL::Library*, NS::Error*> load_colocated_library(
if (first_path.size() != 0) { MTL::Device* device,
auto [lib, error] = load_library_from_path(device, first_path.c_str()); const std::string& lib_name) {
if (lib) { std::string lib_path = get_colocated_mtllib_path(lib_name);
return lib; if (lib_path.size() != 0) {
return load_library_from_path(device, lib_path.c_str());
} }
return {nullptr, nullptr};
} }
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
MTL::Device* device,
const std::string& lib_name) {
#ifdef SWIFTPM_BUNDLE #ifdef SWIFTPM_BUNDLE
// try to load from a swiftpm resource bundle -- scan the available bundles to
// find one that contains the named bundle
{
MTL::Library* library = MTL::Library* library =
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL()); try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name);
if (library != nullptr) { if (library != nullptr) {
return library; return {library, nullptr};
} }
auto bundles = NS::Bundle::allBundles(); auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) { for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i)); auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL()); library = try_load_bundle(device, bundle->resourceURL());
if (library != nullptr) { if (library != nullptr) {
return library; return {library, nullptr};
}
} }
} }
#endif #endif
return {nullptr, nullptr};
}
// Couldn't find it so let's load it from default_mtllib_path MTL::Library* load_default_library(MTL::Device* device) {
{ NS::Error *error1, *error2, *error3;
auto [lib, error] = load_library_from_path(device, lib_path); MTL::Library* lib;
// First try the colocated mlx.metallib
std::tie(lib, error1) = load_colocated_library(device, "mlx");
if (lib) {
return lib;
}
// Then try default.metallib in a SwiftPM bundle if we have one
std::tie(lib, error2) = load_swiftpm_library(device, "default");
if (lib) {
return lib;
}
// Finally try default_mtllib_path
std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path);
if (!lib) { if (!lib) {
std::ostringstream msg; std::ostringstream msg;
msg << error->localizedDescription()->utf8String() << "\n" msg << "Failed to load the default metallib. ";
<< "Failed to load device library from <" << lib_path << ">" if (error1 != nullptr) {
<< " or <" << first_path << ">."; msg << error1->localizedDescription()->utf8String() << " ";
}
if (error2 != nullptr) {
msg << error2->localizedDescription()->utf8String() << " ";
}
if (error3 != nullptr) {
msg << error3->localizedDescription()->utf8String() << " ";
}
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
return lib; return lib;
} }
MTL::Library* load_library(
MTL::Device* device,
const std::string& lib_name,
const std::string& lib_path) {
// We have been given a path that ends in metallib so try to load it
if (lib_path.size() > 9 &&
std::equal(lib_path.end() - 9, lib_path.end(), ".metallib")) {
auto [lib, error] = load_library_from_path(device, lib_path.c_str());
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the metallib from <" << lib_path << "> with error "
<< error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
// We have been given a path so try to load from lib_path / lib_name.metallib
if (lib_path.size() > 0) {
std::string full_path = lib_path + "/" + lib_name + ".metallib";
auto [lib, error] = load_library_from_path(device, full_path.c_str());
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the metallib from <" << full_path
<< "> with error " << error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
// Try to load the colocated library
{
auto [lib, error] = load_colocated_library(device, lib_name);
if (lib) {
return lib;
}
}
// Try to load the library from swiftpm
{
auto [lib, error] = load_swiftpm_library(device, lib_name);
if (lib) {
return lib;
}
}
std::ostringstream msg;
msg << "Failed to load the metallib " << lib_name << ".metallib. "
<< "We attempted to load it from <" << get_colocated_mtllib_path(lib_name)
<< ">";
#ifdef SWIFTPM_BUNDLE
msg << " and from the Swift PM bundle.";
#endif
throw std::runtime_error(msg.str());
} }
} // namespace } // namespace
@ -210,7 +286,7 @@ void CommandEncoder::barrier() {
Device::Device() { Device::Device() {
auto pool = new_scoped_memory_pool(); auto pool = new_scoped_memory_pool();
device_ = load_device(); device_ = load_device();
library_map_ = {{"mlx", load_library(device_)}}; library_map_ = {{"mlx", load_default_library(device_)}};
arch_ = std::string(device_->architecture()->name()->utf8String()); arch_ = std::string(device_->architecture()->name()->utf8String());
auto arch = arch_.back(); auto arch = arch_.back();
switch (arch) { switch (arch) {

View File

@ -189,15 +189,7 @@ class Device {
void register_library( void register_library(
const std::string& lib_name, const std::string& lib_name,
const std::string& lib_path); const std::string& lib_path = "");
// Note, this should remain in the header so that it is not dynamically
// linked
void register_library(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
register_library(lib_name, get_colocated_mtllib_path(lib_name));
}
}
MTL::Library* get_library( MTL::Library* get_library(
const std::string& name, const std::string& name,