diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index ebc3cc77f..549d04f6b 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #define NS_PRIVATE_IMPLEMENTATION #define CA_PRIVATE_IMPLEMENTATION @@ -35,6 +36,16 @@ auto get_metal_version() { return metal_version_; } +static fs::path get_dylib_directory() { + Dl_info info{}; + + if (dladdr(reinterpret_cast(default_mtllib_path), &info) && info.dli_fname) { + fs::path libFile(info.dli_fname); + return libFile.parent_path(); + } + return {}; +} + auto load_device() { auto devices = MTL::CopyAllDevices(); auto device = static_cast(devices->object(0)) @@ -115,7 +126,7 @@ std::pair load_swiftpm_library( } MTL::Library* load_default_library(MTL::Device* device) { - NS::Error* error[4]; + NS::Error* error[5]; MTL::Library* lib; // First try the colocated mlx.metallib std::tie(lib, error[0]) = load_colocated_library(device, "mlx"); @@ -136,10 +147,25 @@ MTL::Library* load_default_library(MTL::Device* device) { // Finally try default_mtllib_path std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path); + if (lib) { + return lib; + } + + { + auto dir = get_dylib_directory(); + if (!dir.empty()) { + auto dylib_path = (dir / default_mtllib_path).string(); + std::tie(lib, error[4]) = load_library_from_path(device, dylib_path.c_str()); + if (lib) { + return lib; + } + } + } + if (!lib) { std::ostringstream msg; msg << "Failed to load the default metallib. "; - for (int i = 0; i < 4; i++) { + for (int i = 0; i < 5; i++) { if (error[i] != nullptr) { msg << error[i]->localizedDescription()->utf8String() << " "; }