diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 95aeb1cc9..f909bc4e6 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -1,5 +1,6 @@ // Copyright © 2023-2024 Apple Inc. +#include #include #include @@ -17,6 +18,22 @@ namespace mlx::core::metal { +std::string search_colocated_mtllib( + const fs::path& dli_path, + const std::string& lib_name) { + auto dir_name = dli_path.parent_path(); + std::array search_list{ + dir_name / lib_name, + dir_name / "Resources" / lib_name // in macOS framework + }; + for (const auto& lib_path : search_list) { + if (fs::exists(lib_path)) { + return lib_path.c_str(); + } + } + return ""; +} + namespace { constexpr const char* default_mtllib_path = METAL_PATH; diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index bb0e93147..66a5eef32 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -19,6 +19,10 @@ namespace fs = std::filesystem; namespace mlx::core::metal { +std::string search_colocated_mtllib( + const fs::path& dli_path, + const std::string& lib_name); + // Note, this function must be left inline in a header so that it is not // dynamically linked. inline std::string get_colocated_mtllib_path(const std::string& lib_name) { @@ -28,8 +32,7 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) { int success = dladdr((void*)get_colocated_mtllib_path, &info); if (success) { - auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext; - mtllib_path = mtllib.c_str(); + mtllib_path = search_colocated_mtllib(fs::path(info.dli_fname), lib_ext); } return mtllib_path;