From 38c1e720c262aa44e08edca30903fdcf81c17c05 Mon Sep 17 00:00:00 2001 From: hdeng-apple Date: Thu, 24 Apr 2025 00:53:13 +0800 Subject: [PATCH] Search mlx.metallib in macOS framework "Resources" dir (#2061) --------- Co-authored-by: Angelos Katharopoulos --- mlx/backend/metal/device.cpp | 48 ++++++++++++++++++++++-------------- mlx/backend/metal/device.h | 14 ++++------- 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 95aeb1cc9..43f82893b 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -1,6 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include +#include #include #include @@ -15,6 +16,8 @@ #include "mlx/backend/metal/utils.h" #include "mlx/utils.h" +namespace fs = std::filesystem; + namespace mlx::core::metal { namespace { @@ -79,12 +82,18 @@ MTL::Library* try_load_bundle( // Firstly, search for the metallib in the same path as this binary std::pair load_colocated_library( MTL::Device* device, - const std::string& lib_name) { - std::string lib_path = get_colocated_mtllib_path(lib_name); - if (lib_path.size() != 0) { - return load_library_from_path(device, lib_path.c_str()); + const std::string& relative_path) { + std::string binary_dir = get_binary_directory(); + if (binary_dir.size() == 0) { + return {nullptr, nullptr}; } - return {nullptr, nullptr}; + + auto path = fs::path(binary_dir) / relative_path; + if (!path.has_extension()) { + path.replace_extension(".metallib"); + } + + return load_library_from_path(device, path.c_str()); } std::pair load_swiftpm_library( @@ -109,33 +118,34 @@ std::pair load_swiftpm_library( } MTL::Library* load_default_library(MTL::Device* device) { - NS::Error *error1, *error2, *error3; + NS::Error* error[4]; MTL::Library* lib; // First try the colocated mlx.metallib - std::tie(lib, error1) = load_colocated_library(device, "mlx"); + std::tie(lib, error[0]) = load_colocated_library(device, "mlx"); + if (lib) { + return lib; + } + + std::tie(lib, error[1]) = load_colocated_library(device, "Resources/mlx"); if (lib) { return lib; } // Then try default.metallib in a SwiftPM bundle if we have one - std::tie(lib, error2) = load_swiftpm_library(device, "default"); + std::tie(lib, error[2]) = load_swiftpm_library(device, "default"); if (lib) { return lib; } // Finally try default_mtllib_path - std::tie(lib, error3) = load_library_from_path(device, default_mtllib_path); + std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path); if (!lib) { std::ostringstream msg; msg << "Failed to load the default metallib. "; - if (error1 != nullptr) { - msg << error1->localizedDescription()->utf8String() << " "; - } - if (error2 != nullptr) { - msg << error2->localizedDescription()->utf8String() << " "; - } - if (error3 != nullptr) { - msg << error3->localizedDescription()->utf8String() << " "; + for (int i = 0; i < 4; i++) { + if (error[i] != nullptr) { + msg << error[i]->localizedDescription()->utf8String() << " "; + } } throw std::runtime_error(msg.str()); } @@ -188,8 +198,8 @@ MTL::Library* load_library( std::ostringstream msg; msg << "Failed to load the metallib " << lib_name << ".metallib. " - << "We attempted to load it from <" << get_colocated_mtllib_path(lib_name) - << ">"; + << "We attempted to load it from <" << get_binary_directory() << "/" + << lib_name << ".metallib" << ">"; #ifdef SWIFTPM_BUNDLE msg << " and from the Swift PM bundle."; #endif diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index bb0e93147..d60635e39 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -21,18 +21,14 @@ namespace mlx::core::metal { // Note, this function must be left inline in a header so that it is not // dynamically linked. -inline std::string get_colocated_mtllib_path(const std::string& lib_name) { +inline std::string get_binary_directory() { Dl_info info; - std::string mtllib_path; - std::string lib_ext = lib_name + ".metallib"; - - int success = dladdr((void*)get_colocated_mtllib_path, &info); + std::string directory; + int success = dladdr((void*)get_binary_directory, &info); if (success) { - auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext; - mtllib_path = mtllib.c_str(); + directory = fs::path(info.dli_fname).remove_filename().c_str(); } - - return mtllib_path; + return directory; } using MTLFCList =