diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 930e570e2..95aeb1cc9 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -55,7 +55,10 @@ std::pair load_library_from_path( } #ifdef SWIFTPM_BUNDLE -MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) { +MTL::Library* try_load_bundle( + MTL::Device* device, + NS::URL* url, + const std::string& lib_name) { std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" + SWIFTPM_BUNDLE + ".bundle"; auto bundle = NS::Bundle::alloc()->init( @@ -63,8 +66,8 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) { if (bundle != nullptr) { std::string resource_path = std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" + - "default.metallib"; - auto [lib, error] = load_library_from_path(device, resource_path.c_str()); + lib_name + ".metallib" auto [lib, error] = + load_library_from_path(device, resource_path.c_str()); if (lib) { return lib; } @@ -73,51 +76,124 @@ MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) { } #endif +// Firstly, search for the metallib in the same path as this binary +std::pair load_colocated_library( + MTL::Device* device, + const std::string& lib_name) { + std::string lib_path = get_colocated_mtllib_path(lib_name); + if (lib_path.size() != 0) { + return load_library_from_path(device, lib_path.c_str()); + } + return {nullptr, nullptr}; +} + +std::pair load_swiftpm_library( + MTL::Device* device, + const std::string& lib_name) { +#ifdef SWIFTPM_BUNDLE + MTL::Library* library = + try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name); + if (library != nullptr) { + return {library, nullptr}; + } + auto bundles = NS::Bundle::allBundles(); + for (int i = 0, c = (int)bundles->count(); i < c; i++) { + auto bundle = reinterpret_cast(bundles->object(i)); + library = try_load_bundle(device, bundle->resourceURL()); + if (library != nullptr) { + return {library, nullptr}; + } + } +#endif + return {nullptr, nullptr}; +} + +MTL::Library* load_default_library(MTL::Device* device) { + NS::Error *error1, *error2, *error3; + MTL::Library* lib; + // First try the colocated mlx.metallib + std::tie(lib, error1) = load_colocated_library(device, "mlx"); + if (lib) { + return lib; + } + + // Then try default.metallib in a SwiftPM bundle if we have one + std::tie(lib, error2) = load_swiftpm_library(device, "default"); + if (lib) { + return lib; + } + + // Finally try default_mtllib_path + std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path); + if (!lib) { + std::ostringstream msg; + msg << "Failed to load the default metallib. "; + if (error1 != nullptr) { + msg << error1->localizedDescription()->utf8String() << " "; + } + if (error2 != nullptr) { + msg << error2->localizedDescription()->utf8String() << " "; + } + if (error3 != nullptr) { + msg << error3->localizedDescription()->utf8String() << " "; + } + throw std::runtime_error(msg.str()); + } + return lib; +} + MTL::Library* load_library( MTL::Device* device, - const std::string& lib_name = "mlx", - const char* lib_path = default_mtllib_path) { - // Firstly, search for the metallib in the same path as this binary - std::string first_path = get_colocated_mtllib_path(lib_name); - if (first_path.size() != 0) { - auto [lib, error] = load_library_from_path(device, first_path.c_str()); + const std::string& lib_name, + const std::string& lib_path) { + // We have been given a path that ends in metallib so try to load it + if (lib_path.size() > 9 && + std::equal(lib_path.end() - 9, lib_path.end(), ".metallib")) { + auto [lib, error] = load_library_from_path(device, lib_path.c_str()); + if (!lib) { + std::ostringstream msg; + msg << "Failed to load the metallib from <" << lib_path << "> with error " + << error->localizedDescription()->utf8String(); + throw std::runtime_error(msg.str()); + } + } + + // We have been given a path so try to load from lib_path / lib_name.metallib + if (lib_path.size() > 0) { + std::string full_path = lib_path + "/" + lib_name + ".metallib"; + auto [lib, error] = load_library_from_path(device, full_path.c_str()); + if (!lib) { + std::ostringstream msg; + msg << "Failed to load the metallib from <" << full_path + << "> with error " << error->localizedDescription()->utf8String(); + throw std::runtime_error(msg.str()); + } + } + + // Try to load the colocated library + { + auto [lib, error] = load_colocated_library(device, lib_name); if (lib) { return lib; } } -#ifdef SWIFTPM_BUNDLE - // try to load from a swiftpm resource bundle -- scan the available bundles to - // find one that contains the named bundle + // Try to load the library from swiftpm { - MTL::Library* library = - try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL()); - if (library != nullptr) { - return library; - } - auto bundles = NS::Bundle::allBundles(); - for (int i = 0, c = (int)bundles->count(); i < c; i++) { - auto bundle = reinterpret_cast(bundles->object(i)); - library = try_load_bundle(device, bundle->resourceURL()); - if (library != nullptr) { - return library; - } + auto [lib, error] = load_swiftpm_library(device, lib_name); + if (lib) { + return lib; } } -#endif - // Couldn't find it so let's load it from default_mtllib_path - { - auto [lib, error] = load_library_from_path(device, lib_path); - if (!lib) { - std::ostringstream msg; - msg << error->localizedDescription()->utf8String() << "\n" - << "Failed to load device library from <" << lib_path << ">" - << " or <" << first_path << ">."; - throw std::runtime_error(msg.str()); - } - return lib; - } + std::ostringstream msg; + msg << "Failed to load the metallib " << lib_name << ".metallib. " + << "We attempted to load it from <" << get_colocated_mtllib_path(lib_name) + << ">"; +#ifdef SWIFTPM_BUNDLE + msg << " and from the Swift PM bundle."; +#endif + throw std::runtime_error(msg.str()); } } // namespace @@ -210,7 +286,7 @@ void CommandEncoder::barrier() { Device::Device() { auto pool = new_scoped_memory_pool(); device_ = load_device(); - library_map_ = {{"mlx", load_library(device_)}}; + library_map_ = {{"mlx", load_default_library(device_)}}; arch_ = std::string(device_->architecture()->name()->utf8String()); auto arch = arch_.back(); switch (arch) { diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 1fe7cf76f..bb0e93147 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -189,15 +189,7 @@ class Device { void register_library( const std::string& lib_name, - const std::string& lib_path); - - // Note, this should remain in the header so that it is not dynamically - // linked - void register_library(const std::string& lib_name) { - if (auto it = library_map_.find(lib_name); it == library_map_.end()) { - register_library(lib_name, get_colocated_mtllib_path(lib_name)); - } - } + const std::string& lib_path = ""); MTL::Library* get_library( const std::string& name,