From 860d3a50d7b9d2c1b3c9ff24e524b0523eb68181 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 26 Aug 2024 09:18:50 -0700 Subject: [PATCH] fix extension metal library finding (#1361) --- mlx/backend/metal/device.cpp | 24 ------------------------ mlx/backend/metal/device.h | 28 +++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 2ab0fa960..8dfcf81e5 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -1,8 +1,6 @@ // Copyright © 2023-2024 Apple Inc. -#include #include -#include #include #include @@ -16,8 +14,6 @@ #include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/utils.h" -namespace fs = std::filesystem; - namespace mlx::core::metal { namespace { @@ -38,20 +34,6 @@ constexpr auto get_metal_version() { #endif } -std::string get_colocated_mtllib_path(const std::string& lib_name) { - Dl_info info; - std::string mtllib_path; - std::string lib_ext = lib_name + ".metallib"; - - int success = dladdr((void*)get_colocated_mtllib_path, &info); - if (success) { - auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext; - mtllib_path = mtllib.c_str(); - } - - return mtllib_path; -} - auto load_device() { auto devices = MTL::CopyAllDevices(); auto device = static_cast(devices->object(0)) @@ -311,12 +293,6 @@ void Device::register_library( } } -void Device::register_library(const std::string& lib_name) { - if (auto it = library_map_.find(lib_name); it == library_map_.end()) { - register_library(lib_name, get_colocated_mtllib_path(lib_name)); - } -} - MTL::Library* Device::get_library_cache_(const std::string& lib_name) { // Search for cached metal lib MTL::Library* mtl_lib; diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index a28dd832e..2841e8103 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -3,6 +3,8 @@ #pragma once #include +#include +#include #include #include #include @@ -12,8 +14,26 @@ #include "mlx/array.h" #include "mlx/device.h" +namespace fs = std::filesystem; + namespace mlx::core::metal { +// Note, this function must be left inline in a header so that it is not +// dynamically linked. +inline std::string get_colocated_mtllib_path(const std::string& lib_name) { + Dl_info info; + std::string mtllib_path; + std::string lib_ext = lib_name + ".metallib"; + + int success = dladdr((void*)get_colocated_mtllib_path, &info); + if (success) { + auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext; + mtllib_path = mtllib.c_str(); + } + + return mtllib_path; +} + using MTLFCList = std::vector>; @@ -86,7 +106,13 @@ class Device { const std::string& lib_name, const std::string& lib_path); - void register_library(const std::string& lib_name); + // Note, this should remain in the header so that it is not dynamically + // linked + void register_library(const std::string& lib_name) { + if (auto it = library_map_.find(lib_name); it == library_map_.end()) { + register_library(lib_name, get_colocated_mtllib_path(lib_name)); + } + } MTL::Library* get_library(const std::string& name);